roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/setup/grid.py CHANGED
@@ -1,51 +1,97 @@
1
1
  import copy
2
- from dataclasses import dataclass, field
3
- from datetime import date
2
+ from dataclasses import dataclass, field, asdict
4
3
 
5
4
  import numpy as np
6
5
  import xarray as xr
6
+ import yaml
7
+ import importlib.metadata
7
8
 
8
- from typing import Any
9
+ from roms_tools.setup.topography import _add_topography_and_mask, _add_velocity_masks
10
+ from roms_tools.setup.plot import _plot
11
+ from roms_tools.setup.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v
9
12
 
13
+ import warnings
10
14
 
11
15
  RADIUS_OF_EARTH = 6371315.0 # in m
12
16
 
13
17
 
14
- # TODO lat_rho and lon_rho should be coordinate variables
15
-
16
18
  # TODO should we store an xgcm.Grid object instead of an xarray Dataset? Or even subclass xgcm.Grid?
17
19
 
18
20
 
19
21
  @dataclass(frozen=True, kw_only=True)
20
22
  class Grid:
21
23
  """
22
- A single ROMS grid.
23
-
24
- Used for creating, plotting, and then saving a new ROMS domain grid.
25
-
26
- Parameters
27
- ----------
28
- nx
29
- Number of grid points in the x-direction
30
- ny
31
- Number of grid points in the y-direction
32
- size_x
33
- Domain size in the x-direction (in km?)
34
- size_y
35
- Domain size in the y-direction (in km?)
36
- center_lon
37
- Longitude of grid center
38
- center_lat
39
- Latitude of grid center
40
- rot
41
- Rotation of grid x-direction from lines of constant latitude.
42
- Measured in degrees, with positive values meaning a counterclockwise rotation.
43
- The default is 0, which means that the x-direction of the grid x-direction is aligned with lines of constant latitude.
44
-
45
- Raises
46
- ------
47
- ValueError
48
- If you try to create a grid which crosses the Greenwich Meridian
24
+ A single ROMS grid.
25
+
26
+ Used for creating, plotting, and then saving a new ROMS domain grid.
27
+
28
+ Parameters
29
+ ----------
30
+ nx : int
31
+ Number of grid points in the x-direction.
32
+ ny : int
33
+ Number of grid points in the y-direction.
34
+ size_x : float
35
+ Domain size in the x-direction (in kilometers).
36
+ size_y : float
37
+ Domain size in the y-direction (in kilometers).
38
+ center_lon : float
39
+ Longitude of grid center.
40
+ center_lat : float
41
+ Latitude of grid center.
42
+ rot : float, optional
43
+ Rotation of grid x-direction from lines of constant latitude, measured in degrees.
44
+ Positive values represent a counterclockwise rotation.
45
+ The default is 0, which means that the x-direction of the grid is aligned with lines of constant latitude.
46
+ topography_source : str, optional
47
+ Specifies the data source to use for the topography. Options are
48
+ "ETOPO5". The default is "ETOPO5".
49
+ smooth_factor : float, optional
50
+ The smoothing factor used in the domain-wide Gaussian smoothing of the
51
+ topography. Smaller values result in less smoothing, while larger
52
+ values produce more smoothing. The default is 8.
53
+ hmin : float, optional
54
+ The minimum ocean depth (in meters). The default is 5.
55
+ rmax : float, optional
56
+ The maximum slope parameter (in meters). This parameter controls
57
+ the local smoothing of the topography. Smaller values result in
58
+ smoother topography, while larger values preserve more detail.
59
+ The default is 0.2.
60
+
61
+ Attributes
62
+ ----------
63
+ nx : int
64
+ Number of grid points in the x-direction.
65
+ ny : int
66
+ Number of grid points in the y-direction.
67
+ size_x : float
68
+ Domain size in the x-direction (in kilometers).
69
+ size_y : float
70
+ Domain size in the y-direction (in kilometers).
71
+ center_lon : float
72
+ Longitude of grid center.
73
+ center_lat : float
74
+ Latitude of grid center.
75
+ rot : float
76
+ Rotation of grid x-direction from lines of constant latitude.
77
+ topography_source : str
78
+ Data source used for the topography.
79
+ smooth_factor : int
80
+ Smoothing factor used in the domain-wide Gaussian smoothing of the topography.
81
+ hmin : float
82
+ Minimum ocean depth (in meters).
83
+ rmax : float
84
+ Maximum slope parameter (in meters).
85
+ ds : xr.Dataset
86
+ The xarray Dataset containing the grid data.
87
+ straddle : bool
88
+ Indicates if the Greenwich meridian (0° longitude) intersects the domain.
89
+ `True` if it does, `False` otherwise.
90
+
91
+ Raises
92
+ ------
93
+ ValueError
94
+ If you try to create a grid with domain size larger than 20000 km.
49
95
  """
50
96
 
51
97
  nx: int
@@ -55,7 +101,12 @@ class Grid:
55
101
  center_lon: float
56
102
  center_lat: float
57
103
  rot: float = 0
104
+ topography_source: str = "ETOPO5"
105
+ smooth_factor: int = 8
106
+ hmin: float = 5.0
107
+ rmax: float = 0.2
58
108
  ds: xr.Dataset = field(init=False, repr=False)
109
+ straddle: bool = field(init=False, repr=False)
59
110
 
60
111
  def __post_init__(self):
61
112
  ds = _make_grid_ds(
@@ -71,6 +122,82 @@ class Grid:
71
122
  # see https://stackoverflow.com/questions/53756788/how-to-set-the-value-of-dataclass-field-in-post-init-when-frozen-true
72
123
  object.__setattr__(self, "ds", ds)
73
124
 
125
+ # Update self.ds with topography and mask information
126
+ self.add_topography_and_mask(
127
+ topography_source=self.topography_source,
128
+ smooth_factor=self.smooth_factor,
129
+ hmin=self.hmin,
130
+ rmax=self.rmax,
131
+ )
132
+
133
+ # Check if the Greenwich meridian goes through the domain.
134
+ self._straddle()
135
+
136
+ def add_topography_and_mask(
137
+ self, topography_source="ETOPO5", smooth_factor=8, hmin=5.0, rmax=0.2
138
+ ) -> None:
139
+ """
140
+ Add topography and mask to the grid dataset.
141
+
142
+ This method processes the topography data and generates a land/sea mask.
143
+ It applies several steps, including interpolating topography, smoothing
144
+ the topography over the entire domain and locally, and filling in enclosed basins. The
145
+ processed topography and mask are added to the grid's dataset as new variables.
146
+
147
+ Parameters
148
+ ----------
149
+ topography_source : str, optional
150
+ Specifies the data source to use for the topography. Options are
151
+ "ETOPO5". The default is "ETOPO5".
152
+ smooth_factor : float, optional
153
+ The smoothing factor used in the domain-wide Gaussian smoothing of the
154
+ topography. Smaller values result in less smoothing, while larger
155
+ values produce more smoothing. The default is 8.
156
+ hmin : float, optional
157
+ The minimum ocean depth (in meters). The default is 5.
158
+ rmax : float, optional
159
+ The maximum slope parameter (in meters). This parameter controls
160
+ the local smoothing of the topography. Smaller values result in
161
+ smoother topography, while larger values preserve more detail.
162
+ The default is 0.2.
163
+
164
+ Returns
165
+ -------
166
+ None
167
+ This method modifies the dataset in place and does not return a value.
168
+ """
169
+
170
+ ds = _add_topography_and_mask(
171
+ self.ds, topography_source, smooth_factor, hmin, rmax
172
+ )
173
+ # Assign the updated dataset back to the frozen dataclass
174
+ object.__setattr__(self, "ds", ds)
175
+
176
+ def compute_bathymetry_laplacian(self):
177
+ """
178
+ Compute the Laplacian of the 'h' field in the provided grid dataset.
179
+
180
+ Adds:
181
+ xarray.DataArray: The Laplacian of the 'h' field as a new variable in the dataset self.ds.
182
+ """
183
+
184
+ # Extract the 'h' field and grid spacing variables
185
+ h = self.ds.h
186
+ pm = self.ds.pm # Reciprocal of grid spacing in x-direction
187
+ pn = self.ds.pn # Reciprocal of grid spacing in y-direction
188
+
189
+ # Compute second derivatives using finite differences
190
+ d2h_dx2 = (h.shift(xi_rho=-1) - 2 * h + h.shift(xi_rho=1)) * pm**2
191
+ d2h_dy2 = (h.shift(eta_rho=-1) - 2 * h + h.shift(eta_rho=1)) * pn**2
192
+
193
+ # Compute the Laplacian by summing second derivatives
194
+ laplacian_h = d2h_dx2 + d2h_dy2
195
+
196
+ # Add the Laplacian as a new variable in the dataset
197
+ self.ds["h_laplacian"] = laplacian_h
198
+ self.ds["h_laplacian"].attrs["long_name"] = "Laplacian of final bathymetry"
199
+ self.ds["h_laplacian"].attrs["units"] = "1/m"
200
+
74
201
  def save(self, filepath: str) -> None:
75
202
  """
76
203
  Save the grid information to a netCDF4 file.
@@ -81,78 +208,251 @@ class Grid:
81
208
  """
82
209
  self.ds.to_netcdf(filepath)
83
210
 
84
- def from_file(self, filepath: str) -> "Grid":
211
+ def to_yaml(self, filepath: str) -> None:
85
212
  """
86
- Open an existing grid from a file.
213
+ Export the parameters of the class to a YAML file, including the version of roms-tools.
87
214
 
88
215
  Parameters
89
216
  ----------
90
- filepath
217
+ filepath : str
218
+ The path to the YAML file where the parameters will be saved.
91
219
  """
92
- # TODO set other parameters that were saved into the file, because every parameter we need gets saved.
93
- # TODO actually we will need to deduce size_x and size_y from the file, that's annoying.
94
- self.ds = xr.open_dataset(filepath)
95
- raise NotImplementedError()
220
+ data = asdict(self)
221
+ data.pop("ds", None)
222
+ data.pop("straddle", None)
223
+
224
+ # Include the version of roms-tools
225
+ try:
226
+ roms_tools_version = importlib.metadata.version("roms-tools")
227
+ except importlib.metadata.PackageNotFoundError:
228
+ roms_tools_version = "unknown"
229
+
230
+ # Create header
231
+ header = f"---\nroms_tools_version: {roms_tools_version}\n---\n"
232
+
233
+ # Use the class name as the top-level key
234
+ yaml_data = {self.__class__.__name__: data}
235
+
236
+ with open(filepath, "w") as file:
237
+ # Write header
238
+ file.write(header)
239
+ # Write YAML data
240
+ yaml.dump(yaml_data, file, default_flow_style=False)
241
+
242
+ @classmethod
243
+ def from_file(cls, filepath: str) -> "Grid":
244
+ """
245
+ Create a Grid instance from an existing file.
96
246
 
97
- def to_xgcm() -> Any:
98
- # TODO we could convert the dataset to an xgcm.Grid object and return here?
99
- raise NotImplementedError()
247
+ Parameters
248
+ ----------
249
+ filepath : str
250
+ Path to the file containing the grid information.
251
+
252
+ Returns
253
+ -------
254
+ Grid
255
+ A new instance of Grid populated with data from the file.
256
+ """
257
+ # Load the dataset from the file
258
+ ds = xr.open_dataset(filepath)
259
+
260
+ if not all(mask in ds for mask in ["mask_u", "mask_v"]):
261
+ ds = _add_velocity_masks(ds)
262
+ if not all(coord in ds for coord in ["lat_u", "lon_u", "lat_v", "lon_v"]):
263
+ ds = _add_lat_lon_at_velocity_points(ds)
264
+
265
+ # Create a new Grid instance without calling __init__ and __post_init__
266
+ grid = cls.__new__(cls)
267
+
268
+ # Set the dataset for the grid instance
269
+ object.__setattr__(grid, "ds", ds)
270
+
271
+ # Check if the Greenwich meridian goes through the domain.
272
+ grid._straddle()
273
+
274
+ # Manually set the remaining attributes by extracting parameters from dataset
275
+ object.__setattr__(grid, "nx", ds.sizes["xi_rho"] - 2)
276
+ object.__setattr__(grid, "ny", ds.sizes["eta_rho"] - 2)
277
+ object.__setattr__(grid, "center_lon", ds["tra_lon"].values.item())
278
+ object.__setattr__(grid, "center_lat", ds["tra_lat"].values.item())
279
+ object.__setattr__(grid, "rot", ds["rotate"].values.item())
280
+
281
+ for attr in [
282
+ "size_x",
283
+ "size_y",
284
+ "topography_source",
285
+ "smooth_factor",
286
+ "hmin",
287
+ "rmax",
288
+ ]:
289
+ if attr in ds.attrs:
290
+ object.__setattr__(grid, attr, ds.attrs[attr])
291
+
292
+ return grid
293
+
294
+ @classmethod
295
+ def from_yaml(cls, filepath: str) -> "Grid":
296
+ """
297
+ Create an instance of the class from a YAML file.
298
+
299
+ Parameters
300
+ ----------
301
+ filepath : str
302
+ The path to the YAML file from which the parameters will be read.
303
+
304
+ Returns
305
+ -------
306
+ Grid
307
+ An instance of the Grid class.
308
+ """
309
+ # Read the entire file content
310
+ with open(filepath, "r") as file:
311
+ file_content = file.read()
312
+
313
+ # Split the content into YAML documents
314
+ documents = list(yaml.safe_load_all(file_content))
315
+
316
+ header_data = None
317
+ grid_data = None
318
+
319
+ # Iterate over documents to find the header and grid configuration
320
+ for doc in documents:
321
+ if doc is None:
322
+ continue
323
+ if "roms_tools_version" in doc:
324
+ header_data = doc
325
+ elif "Grid" in doc:
326
+ grid_data = doc["Grid"]
327
+
328
+ if header_data is None:
329
+ raise ValueError("Version of ROMS-Tools not found in the YAML file.")
330
+ else:
331
+ # Check the roms_tools_version
332
+ roms_tools_version_header = header_data.get("roms_tools_version")
333
+ # Get current version of roms-tools
334
+ try:
335
+ roms_tools_version_current = importlib.metadata.version("roms-tools")
336
+ except importlib.metadata.PackageNotFoundError:
337
+ roms_tools_version_current = "unknown"
338
+
339
+ if roms_tools_version_header != roms_tools_version_current:
340
+ warnings.warn(
341
+ f"Current roms-tools version ({roms_tools_version_current}) does not match the version in the YAML header ({roms_tools_version_header}).",
342
+ UserWarning,
343
+ )
344
+
345
+ if grid_data is None:
346
+ raise ValueError("No Grid configuration found in the YAML file.")
347
+
348
+ return cls(**grid_data)
349
+
350
+ # override __repr__ method to only print attributes that are actually set
351
+ def __repr__(self) -> str:
352
+ cls = self.__class__
353
+ cls_name = cls.__name__
354
+ # Create a dictionary of attribute names and values, filtering out those that are not set and 'ds'
355
+ attr_dict = {
356
+ k: v for k, v in self.__dict__.items() if k != "ds" and v is not None
357
+ }
358
+ attr_str = ", ".join(f"{k}={v!r}" for k, v in attr_dict.items())
359
+ return f"{cls_name}({attr_str})"
360
+
361
+ # def to_xgcm() -> Any:
362
+ # # TODO we could convert the dataset to an xgcm.Grid object and return here?
363
+ # raise NotImplementedError()
364
+
365
+ def _straddle(self) -> None:
366
+ """
367
+ Check if the Greenwich meridian goes through the domain.
368
+
369
+ This method sets the `straddle` attribute to `True` if the Greenwich meridian
370
+ (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
371
+ the `straddle` attribute to `False`.
372
+
373
+ The check is based on whether the longitudinal differences between adjacent
374
+ points exceed 300 degrees, indicating a potential wraparound of longitude.
375
+ """
376
+
377
+ if (
378
+ np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
379
+ or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
380
+ ):
381
+ object.__setattr__(self, "straddle", True)
382
+ else:
383
+ object.__setattr__(self, "straddle", False)
100
384
 
101
385
  def plot(self, bathymetry: bool = False) -> None:
102
386
  """
103
387
  Plot the grid.
104
388
 
105
- Requires cartopy and matplotlib.
106
-
107
389
  Parameters
108
390
  ----------
109
- bathymetry: bool
391
+ bathymetry : bool
110
392
  Whether or not to plot the bathymetry. Default is False.
111
- """
112
393
 
113
- # TODO optionally plot topography on top?
114
- if bathymetry:
115
- raise NotImplementedError()
394
+ Returns
395
+ -------
396
+ None
397
+ This method does not return any value. It generates and displays a plot.
116
398
 
117
- import cartopy.crs as ccrs
118
- import matplotlib.pyplot as plt
399
+ """
119
400
 
120
- lon_deg = (self.ds["lon_rho"] - 360).values
121
- lat_deg = self.ds["lat_rho"].values
401
+ if bathymetry:
402
+ kwargs = {"cmap": "YlGnBu"}
403
+
404
+ _plot(
405
+ self.ds,
406
+ field=self.ds.h.where(self.ds.mask_rho),
407
+ straddle=self.straddle,
408
+ kwargs=kwargs,
409
+ )
410
+ else:
411
+ _plot(self.ds, straddle=self.straddle)
412
+
413
+ def coarsen(self):
414
+ """
415
+ Update the grid by adding grid variables that are coarsened versions of the original
416
+ fine-resoluion grid variables. The coarsening is by a factor of two.
417
+
418
+ The specific variables being coarsened are:
419
+ - `lon_rho` -> `lon_coarse`: Longitude at rho points.
420
+ - `lat_rho` -> `lat_coarse`: Latitude at rho points.
421
+ - `h` -> `h_coarse`: Bathymetry (depth).
422
+ - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
423
+ - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
424
+
425
+ Returns
426
+ -------
427
+ None
428
+
429
+ Modifies
430
+ --------
431
+ self.ds : xr.Dataset
432
+ The dataset attribute of the Grid instance is updated with the new coarser variables.
433
+ """
434
+ d = {
435
+ "lon_rho": "lon_coarse",
436
+ "lat_rho": "lat_coarse",
437
+ "h": "h_coarse",
438
+ "angle": "angle_coarse",
439
+ "mask_rho": "mask_coarse",
440
+ }
122
441
 
123
- # Define projections
124
- geodetic = ccrs.Geodetic()
125
- trans = ccrs.NearsidePerspective(
126
- central_longitude=np.mean(lon_deg), central_latitude=np.mean(lat_deg)
127
- )
442
+ for fine_var, coarse_var in d.items():
443
+ fine_field = self.ds[fine_var]
444
+ if self.straddle and fine_var == "lon_rho":
445
+ fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
128
446
 
129
- # find corners
130
- (lo1, la1) = (lon_deg[0, 0], lat_deg[0, 0])
131
- (lo2, la2) = (lon_deg[0, -1], lat_deg[0, -1])
132
- (lo3, la3) = (lon_deg[-1, -1], lat_deg[-1, -1])
133
- (lo4, la4) = (lon_deg[-1, 0], lat_deg[-1, 0])
134
-
135
- # transform coordinates to projected space
136
- lo1t, la1t = trans.transform_point(lo1, la1, geodetic)
137
- lo2t, la2t = trans.transform_point(lo2, la2, geodetic)
138
- lo3t, la3t = trans.transform_point(lo3, la3, geodetic)
139
- lo4t, la4t = trans.transform_point(lo4, la4, geodetic)
140
-
141
- plt.figure(figsize=(10, 10))
142
- ax = plt.axes(projection=trans)
143
-
144
- ax.plot(
145
- [lo1t, lo2t, lo3t, lo4t, lo1t],
146
- [la1t, la2t, la3t, la4t, la1t],
147
- "ro-",
148
- )
447
+ coarse_field = _f2c(fine_field)
448
+ if fine_var == "lon_rho":
449
+ coarse_field = xr.where(
450
+ coarse_field < 0, coarse_field + 360, coarse_field
451
+ )
149
452
 
150
- ax.coastlines(
151
- resolution="50m", linewidth=0.5, color="black"
152
- ) # add map of coastlines
153
- ax.gridlines()
453
+ self.ds[coarse_var] = coarse_field
154
454
 
155
- plt.show()
455
+ self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0)
156
456
 
157
457
 
158
458
  def _make_grid_ds(
@@ -164,195 +464,145 @@ def _make_grid_ds(
164
464
  center_lat: float,
165
465
  rot: float,
166
466
  ) -> xr.Dataset:
467
+ _raise_if_domain_size_too_large(size_x, size_y)
167
468
 
168
- if size_y > size_x:
169
- domain_length, domain_width = size_x * 1e3, size_y * 1e3 # in m
170
- nl, nw = nx, ny
171
- else:
172
- domain_length, domain_width = size_y * 1e3, size_x * 1e3 # in m
173
- nl, nw = ny, nx
174
-
175
- initial_lon_lat_vars = _make_initial_lon_lat_ds(domain_length, domain_width, nl, nw)
469
+ initial_lon_lat_vars = _make_initial_lon_lat_ds(size_x, size_y, nx, ny)
176
470
 
471
+ # rotate coordinate system
177
472
  rotated_lon_lat_vars = _rotate(*initial_lon_lat_vars, rot)
473
+ lon, *_ = rotated_lon_lat_vars
178
474
 
179
- lon2, *_ = rotated_lon_lat_vars
180
-
181
- _raise_if_crosses_greenwich_meridian(lon2, center_lon)
182
-
475
+ # translate coordinate system
183
476
  translated_lon_lat_vars = _translate(*rotated_lon_lat_vars, center_lat, center_lon)
477
+ lon, lat, lonu, latu, lonv, latv, lonq, latq = translated_lon_lat_vars
184
478
 
185
- lon4, lat4, lonu, latu, lonv, latv, lone, late = translated_lon_lat_vars
186
-
187
- pm, pn = _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv)
188
-
189
- ang = _compute_angle(lon4, lonu, latu, lone)
190
-
191
- ds = _create_grid_ds(
192
- nx,
193
- ny,
194
- lon4,
195
- lat4,
196
- pm,
197
- pn,
198
- ang,
199
- size_x,
200
- size_y,
201
- rot,
202
- center_lon,
203
- center_lat,
204
- lone,
205
- late,
206
- )
479
+ # compute 1/dx and 1/dy
480
+ pm, pn = _compute_coordinate_metrics(lon, lonu, latu, lonv, latv)
207
481
 
208
- # TODO topography
209
- # ds = _make_topography(ds)
482
+ # compute angle of local grid positive x-axis relative to east
483
+ ang = _compute_angle(lon, lonu, latu, lonq)
210
484
 
211
- ds = _add_global_metadata(ds, nx, ny, size_x, size_y, center_lon, center_lat, rot)
485
+ ds = _create_grid_ds(lon, lat, pm, pn, ang, rot, center_lon, center_lat)
212
486
 
213
- return ds
487
+ ds = _add_global_metadata(ds, size_x, size_y)
214
488
 
489
+ return ds
215
490
 
216
- def _raise_if_crosses_greenwich_meridian(lon, center_lon):
217
- # We have to do this before the grid is translated because we don't trust the grid creation routines in that case.
218
-
219
- # TODO it would be nice to handle this case, but we first need to know what ROMS expects / can handle.
220
491
 
221
- # TODO what about grids which cross the international dateline?
492
+ def _raise_if_domain_size_too_large(size_x, size_y):
493
+ threshold = 20000
494
+ if size_x > threshold or size_y > threshold:
495
+ raise ValueError("Domain size has to be smaller than %g km" % threshold)
222
496
 
223
- if np.min(lon + center_lon) < 0 < np.max(lon + center_lon):
224
- raise ValueError("Grid cannot cross Greenwich Meridian")
225
497
 
498
+ def _make_initial_lon_lat_ds(size_x, size_y, nx, ny):
499
+ # Mercator projection around the equator
226
500
 
227
- def _make_initial_lon_lat_ds(domain_length, domain_width, nl, nw):
501
+ # initially define the domain to be longer in x-direction (dimension "length")
502
+ # than in y-direction (dimension "width") to keep grid distortion minimal
503
+ if size_y > size_x:
504
+ domain_length, domain_width = size_y * 1e3, size_x * 1e3 # in m
505
+ nl, nw = ny, nx
506
+ else:
507
+ domain_length, domain_width = size_x * 1e3, size_y * 1e3 # in m
508
+ nl, nw = nx, ny
228
509
 
229
- domain_length_in_degrees_longitude = domain_length / RADIUS_OF_EARTH
230
- domain_width_in_degrees_latitude = domain_width / RADIUS_OF_EARTH
510
+ domain_length_in_degrees = domain_length / RADIUS_OF_EARTH
511
+ domain_width_in_degrees = domain_width / RADIUS_OF_EARTH
231
512
 
232
- longitude_array_1d_in_degrees = (
233
- domain_length_in_degrees_longitude * np.arange(-0.5, nl + 1.5, 1) / nl
234
- - domain_length_in_degrees_longitude / 2
513
+ # 1d array describing the longitudes at cell centers
514
+ x = np.arange(-0.5, nl + 1.5, 1)
515
+ lon_array_1d_in_degrees = (
516
+ domain_length_in_degrees * x / nl - domain_length_in_degrees / 2
517
+ )
518
+ # 1d array describing the longitudes at cell corners (or vorticity points "q")
519
+ xq = np.arange(-1, nl + 2, 1)
520
+ lonq_array_1d_in_degrees_q = (
521
+ domain_length_in_degrees * xq / nl - domain_length_in_degrees / 2
235
522
  )
236
523
 
237
- # TODO I don't fully understand what this piece of code achieves
238
- mul = 1.0
239
- for it in range(1, 101):
524
+ # convert degrees latitude to y-coordinate using Mercator projection
525
+ y1 = np.log(np.tan(np.pi / 4 - domain_width_in_degrees / 4))
526
+ y2 = np.log(np.tan(np.pi / 4 + domain_width_in_degrees / 4))
240
527
 
241
- # convert degrees latitude to y-coordinate using Mercator projection
242
- y1 = np.log(np.tan(np.pi / 4 - domain_width_in_degrees_latitude / 4))
243
- y2 = np.log(np.tan(np.pi / 4 + domain_width_in_degrees_latitude / 4))
528
+ # linearly space points in y-space
529
+ y = (y2 - y1) * np.arange(-0.5, nw + 1.5, 1) / nw + y1
530
+ yq = (y2 - y1) * np.arange(-1, nw + 2) / nw + y1
244
531
 
245
- # linearly space points in y-space
246
- y = (y2 - y1) * np.arange(-0.5, nw + 1.5, 1) / nw + y1
532
+ # inverse Mercator projections
533
+ lat_array_1d_in_degrees = np.arctan(np.sinh(y))
534
+ latq_array_1d_in_degrees = np.arctan(np.sinh(yq))
247
535
 
248
- # convert back to longitude using inverse Mercator projection
249
- # lat1d = 2*np.arctan(np.exp(y)) - np.pi/2
250
- latitude_array_1d_in_degrees = np.arctan(np.sinh(y))
536
+ # 2d grid at cell centers
537
+ lon, lat = np.meshgrid(lon_array_1d_in_degrees, lat_array_1d_in_degrees)
538
+ # 2d grid at cell corners
539
+ lonq, latq = np.meshgrid(lonq_array_1d_in_degrees_q, latq_array_1d_in_degrees)
251
540
 
252
- # find width and height of new grid at central grid point in degrees
253
- latitude_array_1d_in_degrees_cen = 0.5 * (
254
- latitude_array_1d_in_degrees[int(np.round(nw / 2) + 1)]
255
- - latitude_array_1d_in_degrees[int(np.round(nw / 2) - 1)]
256
- )
257
- longitude_array_1d_in_degrees_cen = domain_length_in_degrees_longitude / nl
258
-
259
- # scale the domain width in degreees latitude somehow?
260
- mul = (
261
- latitude_array_1d_in_degrees_cen
262
- / longitude_array_1d_in_degrees_cen
263
- * domain_length_in_degrees_longitude
264
- / domain_width_in_degrees_latitude
265
- * nw
266
- / nl
267
- )
268
- latitude_array_1d_in_degrees = latitude_array_1d_in_degrees / mul
541
+ if size_y > size_x:
542
+ # Rotate grid by 90 degrees because until here the grid has been defined
543
+ # to be longer in x-direction than in y-direction
269
544
 
270
- # TODO what does the 'e' suffix mean?
271
- lon1de = (
272
- domain_length_in_degrees_longitude * np.arange(-1, nl + 2, 1) / nl
273
- - domain_length_in_degrees_longitude / 2
274
- )
275
- ye = (y2 - y1) * np.arange(-1, nw + 2) / nw + y1
276
- # lat1de = 2 * np.arctan(np.exp(ye)) - np.pi/2
277
- lat1de = np.arctan(np.sinh(ye))
278
- lat1de = lat1de / mul
545
+ lon, lat = _rot_sphere(lon, lat, 90)
546
+ lonq, latq = _rot_sphere(lonq, latq, 90)
279
547
 
280
- lon1, lat1 = np.meshgrid(
281
- longitude_array_1d_in_degrees, latitude_array_1d_in_degrees
282
- )
283
- lone, late = np.meshgrid(lon1de, lat1de)
284
- lonu = 0.5 * (lon1[:, :-1] + lon1[:, 1:])
285
- latu = 0.5 * (lat1[:, :-1] + lat1[:, 1:])
286
- lonv = 0.5 * (lon1[:-1, :] + lon1[1:, :])
287
- latv = 0.5 * (lat1[:-1, :] + lat1[1:, :])
288
-
289
- if domain_length > domain_width:
290
- # Rotate grid 90 degrees so that the width is now longer than the length
291
-
292
- lon1, lat1 = rot_sphere(lon1, lat1, 90)
293
- lonu, latu = rot_sphere(lonu, latu, 90)
294
- lonv, latv = rot_sphere(lonv, latv, 90)
295
- lone, late = rot_sphere(lone, late, 90)
296
-
297
- lon1 = np.transpose(np.flip(lon1, 0))
298
- lat1 = np.transpose(np.flip(lat1, 0))
299
- lone = np.transpose(np.flip(lone, 0))
300
- late = np.transpose(np.flip(late, 1))
301
-
302
- lonu_tmp = np.transpose(np.flip(lonv, 0))
303
- latu_tmp = np.transpose(np.flip(latv, 0))
304
- lonv = np.transpose(np.flip(lonu, 0))
305
- latv = np.transpose(np.flip(latu, 0))
306
- lonu = lonu_tmp
307
- latu = latu_tmp
548
+ lon = np.transpose(np.flip(lon, 0))
549
+ lat = np.transpose(np.flip(lat, 0))
550
+ lonq = np.transpose(np.flip(lonq, 0))
551
+ latq = np.transpose(np.flip(latq, 0))
552
+
553
+ # infer longitudes and latitudes at u- and v-points
554
+ lonu = 0.5 * (lon[:, :-1] + lon[:, 1:])
555
+ latu = 0.5 * (lat[:, :-1] + lat[:, 1:])
556
+ lonv = 0.5 * (lon[:-1, :] + lon[1:, :])
557
+ latv = 0.5 * (lat[:-1, :] + lat[1:, :])
308
558
 
309
559
  # TODO wrap up into temporary container Dataset object?
310
- return lon1, lat1, lonu, latu, lonv, latv, lone, late
560
+ return lon, lat, lonu, latu, lonv, latv, lonq, latq
311
561
 
312
562
 
313
- def _rotate(lon1, lat1, lonu, latu, lonv, latv, lone, late, rot):
563
+ def _rotate(lon, lat, lonu, latu, lonv, latv, lonq, latq, rot):
314
564
  """Rotate grid counterclockwise relative to surface of Earth by rot degrees"""
315
565
 
316
- (lon2, lat2) = rot_sphere(lon1, lat1, rot)
317
- (lonu, latu) = rot_sphere(lonu, latu, rot)
318
- (lonv, latv) = rot_sphere(lonv, latv, rot)
319
- (lone, late) = rot_sphere(lone, late, rot)
566
+ (lon, lat) = _rot_sphere(lon, lat, rot)
567
+ (lonu, latu) = _rot_sphere(lonu, latu, rot)
568
+ (lonv, latv) = _rot_sphere(lonv, latv, rot)
569
+ (lonq, latq) = _rot_sphere(lonq, latq, rot)
320
570
 
321
- return lon2, lat2, lonu, latu, lonv, latv, lone, late
571
+ return lon, lat, lonu, latu, lonv, latv, lonq, latq
322
572
 
323
573
 
324
- def _translate(lon2, lat2, lonu, latu, lonv, latv, lone, late, tra_lat, tra_lon):
574
+ def _translate(lon, lat, lonu, latu, lonv, latv, lonq, latq, tra_lat, tra_lon):
325
575
  """Translate grid so that the centre lies at the position (tra_lat, tra_lon)"""
326
576
 
327
- (lon3, lat3) = tra_sphere(lon2, lat2, tra_lat)
328
- (lonu, latu) = tra_sphere(lonu, latu, tra_lat)
329
- (lonv, latv) = tra_sphere(lonv, latv, tra_lat)
330
- (lone, late) = tra_sphere(lone, late, tra_lat)
577
+ (lon, lat) = _tra_sphere(lon, lat, tra_lat)
578
+ (lonu, latu) = _tra_sphere(lonu, latu, tra_lat)
579
+ (lonv, latv) = _tra_sphere(lonv, latv, tra_lat)
580
+ (lonq, latq) = _tra_sphere(lonq, latq, tra_lat)
331
581
 
332
- lon4 = lon3 + tra_lon * np.pi / 180
582
+ lon = lon + tra_lon * np.pi / 180
333
583
  lonu = lonu + tra_lon * np.pi / 180
334
584
  lonv = lonv + tra_lon * np.pi / 180
335
- lone = lone + tra_lon * np.pi / 180
336
- lon4[lon4 < -np.pi] = lon4[lon4 < -np.pi] + 2 * np.pi
585
+ lonq = lonq + tra_lon * np.pi / 180
586
+
587
+ lon[lon < -np.pi] = lon[lon < -np.pi] + 2 * np.pi
337
588
  lonu[lonu < -np.pi] = lonu[lonu < -np.pi] + 2 * np.pi
338
589
  lonv[lonv < -np.pi] = lonv[lonv < -np.pi] + 2 * np.pi
339
- lone[lone < -np.pi] = lone[lone < -np.pi] + 2 * np.pi
340
- lat4 = lat3
341
-
342
- return lon4, lat4, lonu, latu, lonv, latv, lone, late
590
+ lonq[lonq < -np.pi] = lonq[lonq < -np.pi] + 2 * np.pi
343
591
 
592
+ return lon, lat, lonu, latu, lonv, latv, lonq, latq
344
593
 
345
- def rot_sphere(lon1, lat1, rot):
346
594
 
347
- (n, m) = np.shape(lon1)
595
+ def _rot_sphere(lon, lat, rot):
596
+ (n, m) = np.shape(lon)
597
+ # convert rotation angle from degrees to radians
348
598
  rot = rot * np.pi / 180
349
599
 
350
- # translate into x,y,z
600
+ # translate into Cartesian coordinates x,y,z
351
601
  # conventions: (lon,lat) = (0,0) corresponds to (x,y,z) = ( 0,-r, 0)
352
602
  # (lon,lat) = (0,90) corresponds to (x,y,z) = ( 0, 0, r)
353
- x1 = np.sin(lon1) * np.cos(lat1)
354
- y1 = np.cos(lon1) * np.cos(lat1)
355
- z1 = np.sin(lat1)
603
+ x1 = np.sin(lon) * np.cos(lat)
604
+ y1 = np.cos(lon) * np.cos(lat)
605
+ z1 = np.sin(lat)
356
606
 
357
607
  # We will rotate these points around the small circle defined by
358
608
  # the intersection of the sphere and the plane that
@@ -379,38 +629,33 @@ def rot_sphere(lon1, lat1, rot):
379
629
  y2 = y1
380
630
  z2 = rp1 * np.sin(ap2)
381
631
 
382
- lon2 = np.pi / 2 * np.ones((n, m))
383
- lon2[abs(y2) > 1e-7] = np.arctan(
632
+ lon = np.pi / 2 * np.ones((n, m))
633
+ lon[abs(y2) > 1e-7] = np.arctan(
384
634
  np.abs(x2[np.abs(y2) > 1e-7] / y2[np.abs(y2) > 1e-7])
385
635
  )
386
- lon2[y2 < 0] = np.pi - lon2[y2 < 0]
387
- lon2[x2 < 0] = -lon2[x2 < 0]
636
+ lon[y2 < 0] = np.pi - lon[y2 < 0]
637
+ lon[x2 < 0] = -lon[x2 < 0]
388
638
 
389
639
  pr2 = np.sqrt(x2**2 + y2**2)
390
- lat2 = np.pi / 2 * np.ones((n, m))
391
- lat2[np.abs(pr2) > 1e-7] = np.arctan(
640
+ lat = np.pi / 2 * np.ones((n, m))
641
+ lat[np.abs(pr2) > 1e-7] = np.arctan(
392
642
  np.abs(z2[np.abs(pr2) > 1e-7] / pr2[np.abs(pr2) > 1e-7])
393
643
  )
394
- lat2[z2 < 0] = -lat2[z2 < 0]
395
-
396
- return (lon2, lat2)
397
-
644
+ lat[z2 < 0] = -lat[z2 < 0]
398
645
 
399
- def tra_sphere(lon1, lat1, tra):
646
+ return (lon, lat)
400
647
 
401
- # Rotate sphere around its y-axis
402
- # Part of easy grid
403
- # (c) 2008, Jeroen Molemaker, UCLA
404
648
 
405
- (n, m) = np.shape(lon1)
649
+ def _tra_sphere(lon, lat, tra):
650
+ (n, m) = np.shape(lon)
406
651
  tra = tra * np.pi / 180 # translation in latitude direction
407
652
 
408
653
  # translate into x,y,z
409
654
  # conventions: (lon,lat) = (0,0) corresponds to (x,y,z) = ( 0,-r, 0)
410
655
  # (lon,lat) = (0,90) corresponds to (x,y,z) = ( 0, 0, r)
411
- x1 = np.sin(lon1) * np.cos(lat1)
412
- y1 = np.cos(lon1) * np.cos(lat1)
413
- z1 = np.sin(lat1)
656
+ x1 = np.sin(lon) * np.cos(lat)
657
+ y1 = np.cos(lon) * np.cos(lat)
658
+ z1 = np.sin(lat)
414
659
 
415
660
  # We will rotate these points around the small circle defined by
416
661
  # the intersection of the sphere and the plane that
@@ -438,29 +683,29 @@ def tra_sphere(lon1, lat1, tra):
438
683
  z2 = rp1 * np.sin(ap2)
439
684
 
440
685
  ## transformation from (x,y,z) to (lat,lon)
441
- lon2 = np.pi / 2 * np.ones((n, m))
442
- lon2[np.abs(y2) > 1e-7] = np.arctan(
686
+ lon = np.pi / 2 * np.ones((n, m))
687
+ lon[np.abs(y2) > 1e-7] = np.arctan(
443
688
  np.abs(x2[np.abs(y2) > 1e-7] / y2[np.abs(y2) > 1e-7])
444
689
  )
445
- lon2[y2 < 0] = np.pi - lon2[y2 < 0]
446
- lon2[x2 < 0] = -lon2[x2 < 0]
690
+ lon[y2 < 0] = np.pi - lon[y2 < 0]
691
+ lon[x2 < 0] = -lon[x2 < 0]
447
692
 
448
693
  pr2 = np.sqrt(x2**2 + y2**2)
449
- lat2 = np.pi / (2 * np.ones((n, m)))
450
- lat2[np.abs(pr2) > 1e-7] = np.arctan(
694
+ lat = np.pi / (2 * np.ones((n, m)))
695
+ lat[np.abs(pr2) > 1e-7] = np.arctan(
451
696
  np.abs(z2[np.abs(pr2) > 1e-7] / pr2[np.abs(pr2) > 1e-7])
452
697
  )
453
- lat2[z2 < 0] = -lat2[z2 < 0]
698
+ lat[z2 < 0] = -lat[z2 < 0]
454
699
 
455
- return (lon2, lat2)
700
+ return (lon, lat)
456
701
 
457
702
 
458
- def _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv):
703
+ def _compute_coordinate_metrics(lon, lonu, latu, lonv, latv):
459
704
  """Compute the curvilinear coordinate metrics pn and pm, defined as 1/grid spacing"""
460
705
 
461
706
  # pm = 1/dx
462
707
  pmu = gc_dist(lonu[:, :-1], latu[:, :-1], lonu[:, 1:], latu[:, 1:])
463
- pm = 0 * lon4
708
+ pm = 0 * lon
464
709
  pm[:, 1:-1] = pmu
465
710
  pm[:, 0] = pm[:, 1]
466
711
  pm[:, -1] = pm[:, -2]
@@ -468,7 +713,7 @@ def _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv):
468
713
 
469
714
  # pn = 1/dy
470
715
  pnv = gc_dist(lonv[:-1, :], latv[:-1, :], lonv[1:, :], latv[1:, :])
471
- pn = 0 * lon4
716
+ pn = 0 * lon
472
717
  pn[1:-1, :] = pnv
473
718
  pn[0, :] = pn[1, :]
474
719
  pn[-1, :] = pn[-2, :]
@@ -478,7 +723,6 @@ def _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv):
478
723
 
479
724
 
480
725
  def gc_dist(lon1, lat1, lon2, lat2):
481
-
482
726
  # Distance between 2 points along a great circle
483
727
  # lat and lon in radians!!
484
728
  # 2008, Jeroen Molemaker, UCLA
@@ -497,7 +741,7 @@ def gc_dist(lon1, lat1, lon2, lat2):
497
741
  return dis
498
742
 
499
743
 
500
- def _compute_angle(lon4, lonu, latu, lone):
744
+ def _compute_angle(lon, lonu, latu, lonq):
501
745
  """Compute angles of local grid positive x-axis relative to east"""
502
746
 
503
747
  dellat = latu[:, 1:] - latu[:, :-1]
@@ -506,7 +750,7 @@ def _compute_angle(lon4, lonu, latu, lone):
506
750
  dellon[dellon < -np.pi] = dellon[dellon < -np.pi] + 2 * np.pi
507
751
  dellon = dellon * np.cos(0.5 * (latu[:, 1:] + latu[:, :-1]))
508
752
 
509
- ang = copy.copy(lon4)
753
+ ang = copy.copy(lon)
510
754
  ang_s = np.arctan(dellat / (dellon + 1e-16))
511
755
  ang_s[(dellon < 0) & (dellat < 0)] = ang_s[(dellon < 0) & (dellat < 0)] - np.pi
512
756
  ang_s[(dellon < 0) & (dellat >= 0)] = ang_s[(dellon < 0) & (dellat >= 0)] + np.pi
@@ -517,36 +761,29 @@ def _compute_angle(lon4, lonu, latu, lone):
517
761
  ang[:, 0] = ang[:, 1]
518
762
  ang[:, -1] = ang[:, -2]
519
763
 
520
- lon4[lon4 < 0] = lon4[lon4 < 0] + 2 * np.pi
521
- lone[lone < 0] = lone[lone < 0] + 2 * np.pi
764
+ lon[lon < 0] = lon[lon < 0] + 2 * np.pi
765
+ lonq[lonq < 0] = lonq[lonq < 0] + 2 * np.pi
522
766
 
523
767
  return ang
524
768
 
525
769
 
526
770
  def _create_grid_ds(
527
- nx,
528
- ny,
529
771
  lon,
530
772
  lat,
531
773
  pm,
532
774
  pn,
533
775
  angle,
534
- size_x,
535
- size_y,
536
776
  rot,
537
777
  center_lon,
538
778
  center_lat,
539
- lone,
540
- late,
541
779
  ):
542
-
543
- # Coriolis frequency
544
- f0 = 4 * np.pi * np.sin(lat) / (24 * 3600)
545
-
546
- # Create empty xarray.Dataset object to store variables in
547
- ds = xr.Dataset()
548
-
549
- # TODO some of these variables are defined but never written to in Easy Grid
780
+ # Create xarray.Dataset object with lat_rho and lon_rho as coordinates
781
+ ds = xr.Dataset(
782
+ coords={
783
+ "lat_rho": (("eta_rho", "xi_rho"), lat * 180 / np.pi),
784
+ "lon_rho": (("eta_rho", "xi_rho"), lon * 180 / np.pi),
785
+ }
786
+ )
550
787
 
551
788
  ds["angle"] = xr.Variable(
552
789
  data=angle,
@@ -554,17 +791,16 @@ def _create_grid_ds(
554
791
  attrs={"long_name": "Angle between xi axis and east", "units": "radians"},
555
792
  )
556
793
 
557
- # ds['h'] = ...
558
- # TODO hraw comes from topography
559
- # ds['hraw'] = xr.Variable(data=hraw, dims=['eta_rho', 'xi_rho'])
794
+ # Coriolis frequency
795
+ f0 = 4 * np.pi * np.sin(lat) / (24 * 3600)
560
796
 
561
- ds["f0"] = xr.Variable(
797
+ ds["f"] = xr.Variable(
562
798
  data=f0,
563
799
  dims=["eta_rho", "xi_rho"],
564
800
  attrs={"long_name": "Coriolis parameter at rho-points", "units": "second-1"},
565
801
  )
566
- ds["pn"] = xr.Variable(
567
- data=pn,
802
+ ds["pm"] = xr.Variable(
803
+ data=pm,
568
804
  dims=["eta_rho", "xi_rho"],
569
805
  attrs={
570
806
  "long_name": "Curvilinear coordinate metric in xi-direction",
@@ -572,7 +808,7 @@ def _create_grid_ds(
572
808
  },
573
809
  )
574
810
  ds["pn"] = xr.Variable(
575
- data=pm,
811
+ data=pn,
576
812
  dims=["eta_rho", "xi_rho"],
577
813
  attrs={
578
814
  "long_name": "Curvilinear coordinate metric in eta-direction",
@@ -580,71 +816,120 @@ def _create_grid_ds(
580
816
  },
581
817
  )
582
818
 
819
+ ds["tra_lon"] = center_lon
820
+ ds["tra_lon"].attrs["long_name"] = "Longitudinal translation of base grid"
821
+ ds["tra_lon"].attrs["units"] = "degrees East"
822
+
823
+ ds["tra_lat"] = center_lat
824
+ ds["tra_lat"].attrs["long_name"] = "Latitudinal translation of base grid"
825
+ ds["tra_lat"].attrs["units"] = "degrees North"
826
+
827
+ ds["rotate"] = rot
828
+ ds["rotate"].attrs["long_name"] = "Rotation of base grid"
829
+ ds["rotate"].attrs["units"] = "degrees"
830
+
583
831
  ds["lon_rho"] = xr.Variable(
584
832
  data=lon * 180 / np.pi,
585
833
  dims=["eta_rho", "xi_rho"],
586
834
  attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
587
835
  )
836
+
588
837
  ds["lat_rho"] = xr.Variable(
589
838
  data=lat * 180 / np.pi,
590
839
  dims=["eta_rho", "xi_rho"],
591
840
  attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
592
841
  )
593
842
 
594
- ds["spherical"] = xr.Variable(
595
- data=["T"],
596
- dims=["one"],
597
- attrs={
598
- "long_name": "Grid type logical switch",
599
- "option_T": "spherical",
600
- },
601
- )
843
+ ds = _add_lat_lon_at_velocity_points(ds)
602
844
 
603
- # TODO this mask is obtained from hraw
604
- # ds['mask_rho'] = xr.Variable(data=lat * 180 / np.pi, dims=['eta_rho', 'xi_rho'], attrs={'long_name': "latitude of rho-points", 'units': "degrees North"})
845
+ return ds
605
846
 
606
- # TODO this 'one' dimension is completely unneccessary as netCDF can store scalars
607
- ds["tra_lon"] = xr.Variable(
608
- data=[center_lon],
609
- dims=["one"],
610
- attrs={
611
- "long_name": "Longitudinal translation of base grid",
612
- "units": "degrees East",
613
- },
614
- )
615
- ds["tra_lat"] = xr.Variable(
616
- data=[center_lat],
617
- dims=["one"],
618
- attrs={
619
- "long_name": "Latitudinal translation of base grid",
620
- "units": "degrees North",
621
- },
622
- )
623
- ds["rotate"] = xr.Variable(
624
- data=[rot],
625
- dims=["one"],
626
- attrs={"long_name": "Rotation of base grid", "units": "degrees"},
627
- )
628
847
 
629
- # TODO this is never written to
630
- # ds['xy_flip']
848
+ def _add_global_metadata(ds, size_x, size_y):
849
+ ds.attrs["title"] = "ROMS grid created by ROMS-Tools"
850
+
851
+ # Include the version of roms-tools
852
+ try:
853
+ roms_tools_version = importlib.metadata.version("roms-tools")
854
+ except importlib.metadata.PackageNotFoundError:
855
+ roms_tools_version = "unknown"
856
+
857
+ ds.attrs["roms_tools_version"] = roms_tools_version
858
+ ds.attrs["size_x"] = size_x
859
+ ds.attrs["size_y"] = size_y
631
860
 
632
861
  return ds
633
862
 
634
863
 
635
- def _add_global_metadata(ds, nx, ny, size_x, size_y, center_lon, center_lat, rot):
864
+ def _f2c(f):
865
+ """
866
+ Coarsen input xarray DataArray f in both x- and y-direction.
867
+
868
+ Parameters
869
+ ----------
870
+ f : xarray.DataArray
871
+ Input DataArray with dimensions (nxp, nyp).
636
872
 
637
- ds.attrs["Title"] = (
638
- "ROMS grid. Settings:"
639
- f" nx: {nx} ny: {ny} "
640
- f" xsize: {size_x / 1e3} ysize: {size_y / 1e3}"
641
- f" rotate: {rot} Lon: {center_lon} Lat: {center_lat}"
642
- )
643
- ds.attrs["Date"] = date.today()
644
- ds.attrs["Type"] = "ROMS grid produced by roms-tools"
873
+ Returns
874
+ -------
875
+ fc : xarray.DataArray
876
+ Output DataArray with modified dimensions and values.
877
+ """
645
878
 
646
- return ds
879
+ fc = _f2c_xdir(f)
880
+ fc = fc.transpose()
881
+ fc = _f2c_xdir(fc)
882
+ fc = fc.transpose()
883
+ fc = fc.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
884
+
885
+ return fc
647
886
 
648
887
 
649
- def _make_topography(ds):
650
- ...
888
+ def _f2c_xdir(f):
889
+ """
890
+ Coarsen input xarray DataArray f in x-direction.
891
+
892
+ Parameters
893
+ ----------
894
+ f : xarray.DataArray
895
+ Input DataArray with dimensions (nxp, nyp).
896
+
897
+ Returns
898
+ -------
899
+ fc : xarray.DataArray
900
+ Output DataArray with modified dimensions and values.
901
+ """
902
+ nxp, nyp = f.shape
903
+ nxcp = (nxp - 2) // 2 + 2
904
+
905
+ fc = xr.DataArray(np.zeros((nxcp, nyp)), dims=f.dims)
906
+
907
+ # Calculate the interior values
908
+ fc[1:-1, :] = 0.5 * (f[1:-2:2, :] + f[2:-1:2, :])
909
+
910
+ # Calculate the first row
911
+ fc[0, :] = f[0, :] + 0.5 * (f[0, :] - f[1, :])
912
+
913
+ # Calculate the last row
914
+ fc[-1, :] = f[-1, :] + 0.5 * (f[-1, :] - f[-2, :])
915
+
916
+ return fc
917
+
918
+
919
+ def _add_lat_lon_at_velocity_points(ds):
920
+
921
+ lat_u = interpolate_from_rho_to_u(ds["lat_rho"])
922
+ lon_u = interpolate_from_rho_to_u(ds["lon_rho"])
923
+ lat_v = interpolate_from_rho_to_v(ds["lat_rho"])
924
+ lon_v = interpolate_from_rho_to_v(ds["lon_rho"])
925
+
926
+ lat_u.attrs = {"long_name": "latitude of u-points", "units": "degrees North"}
927
+ lon_u.attrs = {"long_name": "longitude of u-points", "units": "degrees East"}
928
+ lat_v.attrs = {"long_name": "latitude of v-points", "units": "degrees North"}
929
+ lon_v.attrs = {"long_name": "longitude of v-points", "units": "degrees East"}
930
+
931
+ ds = ds.assign_coords(
932
+ {"lat_u": lat_u, "lon_u": lon_u, "lat_v": lat_v, "lon_v": lon_v}
933
+ )
934
+
935
+ return ds