roms-tools 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
roms_tools/setup/grid.py CHANGED
@@ -1,51 +1,93 @@
1
1
  import copy
2
2
  from dataclasses import dataclass, field
3
- from datetime import date
4
3
 
5
4
  import numpy as np
6
5
  import xarray as xr
7
6
 
8
- from typing import Any
9
7
 
8
+ from roms_tools.setup.topography import _add_topography_and_mask
9
+ from roms_tools.setup.plot import _plot
10
10
 
11
11
  RADIUS_OF_EARTH = 6371315.0 # in m
12
12
 
13
13
 
14
- # TODO lat_rho and lon_rho should be coordinate variables
15
-
16
14
  # TODO should we store an xgcm.Grid object instead of an xarray Dataset? Or even subclass xgcm.Grid?
17
15
 
18
16
 
19
17
  @dataclass(frozen=True, kw_only=True)
20
18
  class Grid:
21
19
  """
22
- A single ROMS grid.
23
-
24
- Used for creating, plotting, and then saving a new ROMS domain grid.
25
-
26
- Parameters
27
- ----------
28
- nx
29
- Number of grid points in the x-direction
30
- ny
31
- Number of grid points in the y-direction
32
- size_x
33
- Domain size in the x-direction (in km?)
34
- size_y
35
- Domain size in the y-direction (in km?)
36
- center_lon
37
- Longitude of grid center
38
- center_lat
39
- Latitude of grid center
40
- rot
41
- Rotation of grid x-direction from lines of constant latitude.
42
- Measured in degrees, with positive values meaning a counterclockwise rotation.
43
- The default is 0, which means that the x-direction of the grid x-direction is aligned with lines of constant latitude.
44
-
45
- Raises
46
- ------
47
- ValueError
48
- If you try to create a grid which crosses the Greenwich Meridian
20
+ A single ROMS grid.
21
+
22
+ Used for creating, plotting, and then saving a new ROMS domain grid.
23
+
24
+ Parameters
25
+ ----------
26
+ nx : int
27
+ Number of grid points in the x-direction.
28
+ ny : int
29
+ Number of grid points in the y-direction.
30
+ size_x : float
31
+ Domain size in the x-direction (in kilometers).
32
+ size_y : float
33
+ Domain size in the y-direction (in kilometers).
34
+ center_lon : float
35
+ Longitude of grid center.
36
+ center_lat : float
37
+ Latitude of grid center.
38
+ rot : float, optional
39
+ Rotation of grid x-direction from lines of constant latitude, measured in degrees.
40
+ Positive values represent a counterclockwise rotation.
41
+ The default is 0, which means that the x-direction of the grid is aligned with lines of constant latitude.
42
+ topography_source : str, optional
43
+ Specifies the data source to use for the topography. Options are
44
+ "etopo5". The default is "etopo5".
45
+ smooth_factor : float, optional
46
+ The smoothing factor used in the domain-wide Gaussian smoothing of the
47
+ topography. Smaller values result in less smoothing, while larger
48
+ values produce more smoothing. The default is 8.
49
+ hmin : float, optional
50
+ The minimum ocean depth (in meters). The default is 5.
51
+ rmax : float, optional
52
+ The maximum slope parameter (in meters). This parameter controls
53
+ the local smoothing of the topography. Smaller values result in
54
+ smoother topography, while larger values preserve more detail.
55
+ The default is 0.2.
56
+
57
+ Attributes
58
+ ----------
59
+ nx : int
60
+ Number of grid points in the x-direction.
61
+ ny : int
62
+ Number of grid points in the y-direction.
63
+ size_x : float
64
+ Domain size in the x-direction (in kilometers).
65
+ size_y : float
66
+ Domain size in the y-direction (in kilometers).
67
+ center_lon : float
68
+ Longitude of grid center.
69
+ center_lat : float
70
+ Latitude of grid center.
71
+ rot : float
72
+ Rotation of grid x-direction from lines of constant latitude.
73
+ topography_source : str
74
+ Data source used for the topography.
75
+ smooth_factor : int
76
+ Smoothing factor used in the domain-wide Gaussian smoothing of the topography.
77
+ hmin : float
78
+ Minimum ocean depth (in meters).
79
+ rmax : float
80
+ Maximum slope parameter (in meters).
81
+ ds : xr.Dataset
82
+ The xarray Dataset containing the grid data.
83
+ straddle : bool
84
+ Indicates if the Greenwich meridian (0° longitude) intersects the domain.
85
+ `True` if it does, `False` otherwise.
86
+
87
+ Raises
88
+ ------
89
+ ValueError
90
+ If you try to create a grid with domain size larger than 20000 km.
49
91
  """
50
92
 
51
93
  nx: int
@@ -55,7 +97,12 @@ class Grid:
55
97
  center_lon: float
56
98
  center_lat: float
57
99
  rot: float = 0
100
+ topography_source: str = "etopo5"
101
+ smooth_factor: int = 8
102
+ hmin: float = 5.0
103
+ rmax: float = 0.2
58
104
  ds: xr.Dataset = field(init=False, repr=False)
105
+ straddle: bool = field(init=False, repr=False)
59
106
 
60
107
  def __post_init__(self):
61
108
  ds = _make_grid_ds(
@@ -71,6 +118,82 @@ class Grid:
71
118
  # see https://stackoverflow.com/questions/53756788/how-to-set-the-value-of-dataclass-field-in-post-init-when-frozen-true
72
119
  object.__setattr__(self, "ds", ds)
73
120
 
121
+ # Update self.ds with topography and mask information
122
+ self.add_topography_and_mask(
123
+ topography_source=self.topography_source,
124
+ smooth_factor=self.smooth_factor,
125
+ hmin=self.hmin,
126
+ rmax=self.rmax,
127
+ )
128
+
129
+ # Check if the Greenwich meridian goes through the domain.
130
+ self._straddle()
131
+
132
+ def add_topography_and_mask(
133
+ self, topography_source="etopo5", smooth_factor=8, hmin=5.0, rmax=0.2
134
+ ) -> None:
135
+ """
136
+ Add topography and mask to the grid dataset.
137
+
138
+ This method processes the topography data and generates a land/sea mask.
139
+ It applies several steps, including interpolating topography, smoothing
140
+ the topography over the entire domain and locally, and filling in enclosed basins. The
141
+ processed topography and mask are added to the grid's dataset as new variables.
142
+
143
+ Parameters
144
+ ----------
145
+ topography_source : str, optional
146
+ Specifies the data source to use for the topography. Options are
147
+ "etopo5". The default is "etopo5".
148
+ smooth_factor : float, optional
149
+ The smoothing factor used in the domain-wide Gaussian smoothing of the
150
+ topography. Smaller values result in less smoothing, while larger
151
+ values produce more smoothing. The default is 8.
152
+ hmin : float, optional
153
+ The minimum ocean depth (in meters). The default is 5.
154
+ rmax : float, optional
155
+ The maximum slope parameter (in meters). This parameter controls
156
+ the local smoothing of the topography. Smaller values result in
157
+ smoother topography, while larger values preserve more detail.
158
+ The default is 0.2.
159
+
160
+ Returns
161
+ -------
162
+ None
163
+ This method modifies the dataset in place and does not return a value.
164
+ """
165
+
166
+ ds = _add_topography_and_mask(
167
+ self.ds, topography_source, smooth_factor, hmin, rmax
168
+ )
169
+ # Assign the updated dataset back to the frozen dataclass
170
+ object.__setattr__(self, "ds", ds)
171
+
172
+ def compute_bathymetry_laplacian(self):
173
+ """
174
+ Compute the Laplacian of the 'h' field in the provided grid dataset.
175
+
176
+ Adds:
177
+ xarray.DataArray: The Laplacian of the 'h' field as a new variable in the dataset self.ds.
178
+ """
179
+
180
+ # Extract the 'h' field and grid spacing variables
181
+ h = self.ds.h
182
+ pm = self.ds.pm # Reciprocal of grid spacing in x-direction
183
+ pn = self.ds.pn # Reciprocal of grid spacing in y-direction
184
+
185
+ # Compute second derivatives using finite differences
186
+ d2h_dx2 = (h.shift(xi_rho=-1) - 2 * h + h.shift(xi_rho=1)) * pm**2
187
+ d2h_dy2 = (h.shift(eta_rho=-1) - 2 * h + h.shift(eta_rho=1)) * pn**2
188
+
189
+ # Compute the Laplacian by summing second derivatives
190
+ laplacian_h = d2h_dx2 + d2h_dy2
191
+
192
+ # Add the Laplacian as a new variable in the dataset
193
+ self.ds["h_laplacian"] = laplacian_h
194
+ self.ds["h_laplacian"].attrs["long_name"] = "Laplacian of final bathymetry"
195
+ self.ds["h_laplacian"].attrs["units"] = "1/m"
196
+
74
197
  def save(self, filepath: str) -> None:
75
198
  """
76
199
  Save the grid information to a netCDF4 file.
@@ -81,78 +204,158 @@ class Grid:
81
204
  """
82
205
  self.ds.to_netcdf(filepath)
83
206
 
84
- def from_file(self, filepath: str) -> "Grid":
207
+ @classmethod
208
+ def from_file(cls, filepath: str) -> "Grid":
85
209
  """
86
- Open an existing grid from a file.
210
+ Create a Grid instance from an existing file.
87
211
 
88
212
  Parameters
89
213
  ----------
90
- filepath
214
+ filepath : str
215
+ Path to the file containing the grid information.
216
+
217
+ Returns
218
+ -------
219
+ Grid
220
+ A new instance of Grid populated with data from the file.
91
221
  """
92
- # TODO set other parameters that were saved into the file, because every parameter we need gets saved.
93
- # TODO actually we will need to deduce size_x and size_y from the file, that's annoying.
94
- self.ds = xr.open_dataset(filepath)
95
- raise NotImplementedError()
222
+ # Load the dataset from the file
223
+ ds = xr.open_dataset(filepath)
224
+
225
+ # Create a new Grid instance without calling __init__ and __post_init__
226
+ grid = cls.__new__(cls)
227
+
228
+ # Set the dataset for the grid instance
229
+ object.__setattr__(grid, "ds", ds)
230
+
231
+ # Check if the Greenwich meridian goes through the domain.
232
+ grid._straddle()
233
+
234
+ # Manually set the remaining attributes by extracting parameters from dataset
235
+ object.__setattr__(grid, "nx", ds.sizes["xi_rho"] - 2)
236
+ object.__setattr__(grid, "ny", ds.sizes["eta_rho"] - 2)
237
+ object.__setattr__(grid, "center_lon", ds["tra_lon"].values.item())
238
+ object.__setattr__(grid, "center_lat", ds["tra_lat"].values.item())
239
+ object.__setattr__(grid, "rot", ds["rotate"].values.item())
240
+
241
+ for attr in [
242
+ "size_x",
243
+ "size_y",
244
+ "topography_source",
245
+ "smooth_factor",
246
+ "hmin",
247
+ "rmax",
248
+ ]:
249
+ if attr in ds.attrs:
250
+ object.__setattr__(grid, attr, ds.attrs[attr])
251
+
252
+ return grid
253
+
254
+ # override __repr__ method to only print attributes that are actually set
255
+ def __repr__(self) -> str:
256
+ cls = self.__class__
257
+ cls_name = cls.__name__
258
+ # Create a dictionary of attribute names and values, filtering out those that are not set and 'ds'
259
+ attr_dict = {
260
+ k: v for k, v in self.__dict__.items() if k != "ds" and v is not None
261
+ }
262
+ attr_str = ", ".join(f"{k}={v!r}" for k, v in attr_dict.items())
263
+ return f"{cls_name}({attr_str})"
264
+
265
+ # def to_xgcm() -> Any:
266
+ # # TODO we could convert the dataset to an xgcm.Grid object and return here?
267
+ # raise NotImplementedError()
268
+
269
+ def _straddle(self) -> None:
270
+ """
271
+ Check if the Greenwich meridian goes through the domain.
272
+
273
+ This method sets the `straddle` attribute to `True` if the Greenwich meridian
274
+ (0° longitude) intersects the domain defined by `lon_rho`. Otherwise, it sets
275
+ the `straddle` attribute to `False`.
96
276
 
97
- def to_xgcm() -> Any:
98
- # TODO we could convert the dataset to an xgcm.Grid object and return here?
99
- raise NotImplementedError()
277
+ The check is based on whether the longitudinal differences between adjacent
278
+ points exceed 300 degrees, indicating a potential wraparound of longitude.
279
+ """
280
+
281
+ if (
282
+ np.abs(self.ds.lon_rho.diff("xi_rho")).max() > 300
283
+ or np.abs(self.ds.lon_rho.diff("eta_rho")).max() > 300
284
+ ):
285
+ object.__setattr__(self, "straddle", True)
286
+ else:
287
+ object.__setattr__(self, "straddle", False)
100
288
 
101
289
  def plot(self, bathymetry: bool = False) -> None:
102
290
  """
103
291
  Plot the grid.
104
292
 
105
- Requires cartopy and matplotlib.
106
-
107
293
  Parameters
108
294
  ----------
109
- bathymetry: bool
295
+ bathymetry : bool
110
296
  Whether or not to plot the bathymetry. Default is False.
297
+
298
+ Returns
299
+ -------
300
+ None
301
+ This method does not return any value. It generates and displays a plot.
302
+
111
303
  """
112
304
 
113
- # TODO optionally plot topography on top?
114
305
  if bathymetry:
115
- raise NotImplementedError()
306
+ kwargs = {"cmap": "YlGnBu"}
307
+ _plot(
308
+ self.ds,
309
+ field=self.ds.h.where(self.ds.mask_rho),
310
+ straddle=self.straddle,
311
+ kwargs=kwargs,
312
+ )
313
+ else:
314
+ _plot(self.ds, straddle=self.straddle)
315
+
316
+ def coarsen(self):
317
+ """
318
+ Update the grid by adding grid variables that are coarsened versions of the original
319
+ fine-resoluion grid variables. The coarsening is by a factor of two.
320
+
321
+ The specific variables being coarsened are:
322
+ - `lon_rho` -> `lon_coarse`: Longitude at rho points.
323
+ - `lat_rho` -> `lat_coarse`: Latitude at rho points.
324
+ - `h` -> `h_coarse`: Bathymetry (depth).
325
+ - `angle` -> `angle_coarse`: Angle between the xi axis and true east.
326
+ - `mask_rho` -> `mask_coarse`: Land/sea mask at rho points.
327
+
328
+ Returns
329
+ -------
330
+ None
331
+
332
+ Modifies
333
+ --------
334
+ self.ds : xr.Dataset
335
+ The dataset attribute of the Grid instance is updated with the new coarser variables.
336
+ """
337
+ d = {
338
+ "lon_rho": "lon_coarse",
339
+ "lat_rho": "lat_coarse",
340
+ "h": "h_coarse",
341
+ "angle": "angle_coarse",
342
+ "mask_rho": "mask_coarse",
343
+ }
116
344
 
117
- import cartopy.crs as ccrs
118
- import matplotlib.pyplot as plt
345
+ for fine_var, coarse_var in d.items():
346
+ fine_field = self.ds[fine_var]
347
+ if self.straddle and fine_var == "lon_rho":
348
+ fine_field = xr.where(fine_field > 180, fine_field - 360, fine_field)
119
349
 
120
- lon_deg = (self.ds["lon_rho"] - 360).values
121
- lat_deg = self.ds["lat_rho"].values
350
+ coarse_field = _f2c(fine_field)
351
+ if fine_var == "lon_rho":
352
+ coarse_field = xr.where(
353
+ coarse_field < 0, coarse_field + 360, coarse_field
354
+ )
122
355
 
123
- # Define projections
124
- geodetic = ccrs.Geodetic()
125
- trans = ccrs.NearsidePerspective(
126
- central_longitude=np.mean(lon_deg), central_latitude=np.mean(lat_deg)
127
- )
356
+ self.ds[coarse_var] = coarse_field
128
357
 
129
- # find corners
130
- (lo1, la1) = (lon_deg[0, 0], lat_deg[0, 0])
131
- (lo2, la2) = (lon_deg[0, -1], lat_deg[0, -1])
132
- (lo3, la3) = (lon_deg[-1, -1], lat_deg[-1, -1])
133
- (lo4, la4) = (lon_deg[-1, 0], lat_deg[-1, 0])
134
-
135
- # transform coordinates to projected space
136
- lo1t, la1t = trans.transform_point(lo1, la1, geodetic)
137
- lo2t, la2t = trans.transform_point(lo2, la2, geodetic)
138
- lo3t, la3t = trans.transform_point(lo3, la3, geodetic)
139
- lo4t, la4t = trans.transform_point(lo4, la4, geodetic)
140
-
141
- plt.figure(figsize=(10, 10))
142
- ax = plt.axes(projection=trans)
143
-
144
- ax.plot(
145
- [lo1t, lo2t, lo3t, lo4t, lo1t],
146
- [la1t, la2t, la3t, la4t, la1t],
147
- "ro-",
148
- )
149
-
150
- ax.coastlines(
151
- resolution="50m", linewidth=0.5, color="black"
152
- ) # add map of coastlines
153
- ax.gridlines()
154
-
155
- plt.show()
358
+ self.ds["mask_coarse"] = xr.where(self.ds["mask_coarse"] > 0.5, 1, 0)
156
359
 
157
360
 
158
361
  def _make_grid_ds(
@@ -164,195 +367,145 @@ def _make_grid_ds(
164
367
  center_lat: float,
165
368
  rot: float,
166
369
  ) -> xr.Dataset:
370
+ _raise_if_domain_size_too_large(size_x, size_y)
167
371
 
168
- if size_y > size_x:
169
- domain_length, domain_width = size_x * 1e3, size_y * 1e3 # in m
170
- nl, nw = nx, ny
171
- else:
172
- domain_length, domain_width = size_y * 1e3, size_x * 1e3 # in m
173
- nl, nw = ny, nx
174
-
175
- initial_lon_lat_vars = _make_initial_lon_lat_ds(domain_length, domain_width, nl, nw)
372
+ initial_lon_lat_vars = _make_initial_lon_lat_ds(size_x, size_y, nx, ny)
176
373
 
374
+ # rotate coordinate system
177
375
  rotated_lon_lat_vars = _rotate(*initial_lon_lat_vars, rot)
376
+ lon, *_ = rotated_lon_lat_vars
178
377
 
179
- lon2, *_ = rotated_lon_lat_vars
180
-
181
- _raise_if_crosses_greenwich_meridian(lon2, center_lon)
182
-
378
+ # translate coordinate system
183
379
  translated_lon_lat_vars = _translate(*rotated_lon_lat_vars, center_lat, center_lon)
380
+ lon, lat, lonu, latu, lonv, latv, lonq, latq = translated_lon_lat_vars
184
381
 
185
- lon4, lat4, lonu, latu, lonv, latv, lone, late = translated_lon_lat_vars
186
-
187
- pm, pn = _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv)
188
-
189
- ang = _compute_angle(lon4, lonu, latu, lone)
190
-
191
- ds = _create_grid_ds(
192
- nx,
193
- ny,
194
- lon4,
195
- lat4,
196
- pm,
197
- pn,
198
- ang,
199
- size_x,
200
- size_y,
201
- rot,
202
- center_lon,
203
- center_lat,
204
- lone,
205
- late,
206
- )
207
-
208
- # TODO topography
209
- # ds = _make_topography(ds)
382
+ # compute 1/dx and 1/dy
383
+ pm, pn = _compute_coordinate_metrics(lon, lonu, latu, lonv, latv)
210
384
 
211
- ds = _add_global_metadata(ds, nx, ny, size_x, size_y, center_lon, center_lat, rot)
385
+ # compute angle of local grid positive x-axis relative to east
386
+ ang = _compute_angle(lon, lonu, latu, lonq)
212
387
 
213
- return ds
388
+ ds = _create_grid_ds(lon, lat, pm, pn, ang, rot, center_lon, center_lat)
214
389
 
390
+ ds = _add_global_metadata(ds, size_x, size_y)
215
391
 
216
- def _raise_if_crosses_greenwich_meridian(lon, center_lon):
217
- # We have to do this before the grid is translated because we don't trust the grid creation routines in that case.
392
+ return ds
218
393
 
219
- # TODO it would be nice to handle this case, but we first need to know what ROMS expects / can handle.
220
394
 
221
- # TODO what about grids which cross the international dateline?
395
+ def _raise_if_domain_size_too_large(size_x, size_y):
396
+ threshold = 20000
397
+ if size_x > threshold or size_y > threshold:
398
+ raise ValueError("Domain size has to be smaller than %g km" % threshold)
222
399
 
223
- if np.min(lon + center_lon) < 0 < np.max(lon + center_lon):
224
- raise ValueError("Grid cannot cross Greenwich Meridian")
225
400
 
401
+ def _make_initial_lon_lat_ds(size_x, size_y, nx, ny):
402
+ # Mercator projection around the equator
226
403
 
227
- def _make_initial_lon_lat_ds(domain_length, domain_width, nl, nw):
404
+ # initially define the domain to be longer in x-direction (dimension "length")
405
+ # than in y-direction (dimension "width") to keep grid distortion minimal
406
+ if size_y > size_x:
407
+ domain_length, domain_width = size_y * 1e3, size_x * 1e3 # in m
408
+ nl, nw = ny, nx
409
+ else:
410
+ domain_length, domain_width = size_x * 1e3, size_y * 1e3 # in m
411
+ nl, nw = nx, ny
228
412
 
229
- domain_length_in_degrees_longitude = domain_length / RADIUS_OF_EARTH
230
- domain_width_in_degrees_latitude = domain_width / RADIUS_OF_EARTH
413
+ domain_length_in_degrees = domain_length / RADIUS_OF_EARTH
414
+ domain_width_in_degrees = domain_width / RADIUS_OF_EARTH
231
415
 
232
- longitude_array_1d_in_degrees = (
233
- domain_length_in_degrees_longitude * np.arange(-0.5, nl + 1.5, 1) / nl
234
- - domain_length_in_degrees_longitude / 2
416
+ # 1d array describing the longitudes at cell centers
417
+ x = np.arange(-0.5, nl + 1.5, 1)
418
+ lon_array_1d_in_degrees = (
419
+ domain_length_in_degrees * x / nl - domain_length_in_degrees / 2
420
+ )
421
+ # 1d array describing the longitudes at cell corners (or vorticity points "q")
422
+ xq = np.arange(-1, nl + 2, 1)
423
+ lonq_array_1d_in_degrees_q = (
424
+ domain_length_in_degrees * xq / nl - domain_length_in_degrees / 2
235
425
  )
236
426
 
237
- # TODO I don't fully understand what this piece of code achieves
238
- mul = 1.0
239
- for it in range(1, 101):
427
+ # convert degrees latitude to y-coordinate using Mercator projection
428
+ y1 = np.log(np.tan(np.pi / 4 - domain_width_in_degrees / 4))
429
+ y2 = np.log(np.tan(np.pi / 4 + domain_width_in_degrees / 4))
240
430
 
241
- # convert degrees latitude to y-coordinate using Mercator projection
242
- y1 = np.log(np.tan(np.pi / 4 - domain_width_in_degrees_latitude / 4))
243
- y2 = np.log(np.tan(np.pi / 4 + domain_width_in_degrees_latitude / 4))
431
+ # linearly space points in y-space
432
+ y = (y2 - y1) * np.arange(-0.5, nw + 1.5, 1) / nw + y1
433
+ yq = (y2 - y1) * np.arange(-1, nw + 2) / nw + y1
244
434
 
245
- # linearly space points in y-space
246
- y = (y2 - y1) * np.arange(-0.5, nw + 1.5, 1) / nw + y1
435
+ # inverse Mercator projections
436
+ lat_array_1d_in_degrees = np.arctan(np.sinh(y))
437
+ latq_array_1d_in_degrees = np.arctan(np.sinh(yq))
247
438
 
248
- # convert back to longitude using inverse Mercator projection
249
- # lat1d = 2*np.arctan(np.exp(y)) - np.pi/2
250
- latitude_array_1d_in_degrees = np.arctan(np.sinh(y))
439
+ # 2d grid at cell centers
440
+ lon, lat = np.meshgrid(lon_array_1d_in_degrees, lat_array_1d_in_degrees)
441
+ # 2d grid at cell corners
442
+ lonq, latq = np.meshgrid(lonq_array_1d_in_degrees_q, latq_array_1d_in_degrees)
251
443
 
252
- # find width and height of new grid at central grid point in degrees
253
- latitude_array_1d_in_degrees_cen = 0.5 * (
254
- latitude_array_1d_in_degrees[int(np.round(nw / 2) + 1)]
255
- - latitude_array_1d_in_degrees[int(np.round(nw / 2) - 1)]
256
- )
257
- longitude_array_1d_in_degrees_cen = domain_length_in_degrees_longitude / nl
258
-
259
- # scale the domain width in degreees latitude somehow?
260
- mul = (
261
- latitude_array_1d_in_degrees_cen
262
- / longitude_array_1d_in_degrees_cen
263
- * domain_length_in_degrees_longitude
264
- / domain_width_in_degrees_latitude
265
- * nw
266
- / nl
267
- )
268
- latitude_array_1d_in_degrees = latitude_array_1d_in_degrees / mul
444
+ if size_y > size_x:
445
+ # Rotate grid by 90 degrees because until here the grid has been defined
446
+ # to be longer in x-direction than in y-direction
269
447
 
270
- # TODO what does the 'e' suffix mean?
271
- lon1de = (
272
- domain_length_in_degrees_longitude * np.arange(-1, nl + 2, 1) / nl
273
- - domain_length_in_degrees_longitude / 2
274
- )
275
- ye = (y2 - y1) * np.arange(-1, nw + 2) / nw + y1
276
- # lat1de = 2 * np.arctan(np.exp(ye)) - np.pi/2
277
- lat1de = np.arctan(np.sinh(ye))
278
- lat1de = lat1de / mul
448
+ lon, lat = _rot_sphere(lon, lat, 90)
449
+ lonq, latq = _rot_sphere(lonq, latq, 90)
279
450
 
280
- lon1, lat1 = np.meshgrid(
281
- longitude_array_1d_in_degrees, latitude_array_1d_in_degrees
282
- )
283
- lone, late = np.meshgrid(lon1de, lat1de)
284
- lonu = 0.5 * (lon1[:, :-1] + lon1[:, 1:])
285
- latu = 0.5 * (lat1[:, :-1] + lat1[:, 1:])
286
- lonv = 0.5 * (lon1[:-1, :] + lon1[1:, :])
287
- latv = 0.5 * (lat1[:-1, :] + lat1[1:, :])
288
-
289
- if domain_length > domain_width:
290
- # Rotate grid 90 degrees so that the width is now longer than the length
291
-
292
- lon1, lat1 = rot_sphere(lon1, lat1, 90)
293
- lonu, latu = rot_sphere(lonu, latu, 90)
294
- lonv, latv = rot_sphere(lonv, latv, 90)
295
- lone, late = rot_sphere(lone, late, 90)
296
-
297
- lon1 = np.transpose(np.flip(lon1, 0))
298
- lat1 = np.transpose(np.flip(lat1, 0))
299
- lone = np.transpose(np.flip(lone, 0))
300
- late = np.transpose(np.flip(late, 1))
301
-
302
- lonu_tmp = np.transpose(np.flip(lonv, 0))
303
- latu_tmp = np.transpose(np.flip(latv, 0))
304
- lonv = np.transpose(np.flip(lonu, 0))
305
- latv = np.transpose(np.flip(latu, 0))
306
- lonu = lonu_tmp
307
- latu = latu_tmp
451
+ lon = np.transpose(np.flip(lon, 0))
452
+ lat = np.transpose(np.flip(lat, 0))
453
+ lonq = np.transpose(np.flip(lonq, 0))
454
+ latq = np.transpose(np.flip(latq, 0))
455
+
456
+ # infer longitudes and latitudes at u- and v-points
457
+ lonu = 0.5 * (lon[:, :-1] + lon[:, 1:])
458
+ latu = 0.5 * (lat[:, :-1] + lat[:, 1:])
459
+ lonv = 0.5 * (lon[:-1, :] + lon[1:, :])
460
+ latv = 0.5 * (lat[:-1, :] + lat[1:, :])
308
461
 
309
462
  # TODO wrap up into temporary container Dataset object?
310
- return lon1, lat1, lonu, latu, lonv, latv, lone, late
463
+ return lon, lat, lonu, latu, lonv, latv, lonq, latq
311
464
 
312
465
 
313
- def _rotate(lon1, lat1, lonu, latu, lonv, latv, lone, late, rot):
466
+ def _rotate(lon, lat, lonu, latu, lonv, latv, lonq, latq, rot):
314
467
  """Rotate grid counterclockwise relative to surface of Earth by rot degrees"""
315
468
 
316
- (lon2, lat2) = rot_sphere(lon1, lat1, rot)
317
- (lonu, latu) = rot_sphere(lonu, latu, rot)
318
- (lonv, latv) = rot_sphere(lonv, latv, rot)
319
- (lone, late) = rot_sphere(lone, late, rot)
469
+ (lon, lat) = _rot_sphere(lon, lat, rot)
470
+ (lonu, latu) = _rot_sphere(lonu, latu, rot)
471
+ (lonv, latv) = _rot_sphere(lonv, latv, rot)
472
+ (lonq, latq) = _rot_sphere(lonq, latq, rot)
320
473
 
321
- return lon2, lat2, lonu, latu, lonv, latv, lone, late
474
+ return lon, lat, lonu, latu, lonv, latv, lonq, latq
322
475
 
323
476
 
324
- def _translate(lon2, lat2, lonu, latu, lonv, latv, lone, late, tra_lat, tra_lon):
477
+ def _translate(lon, lat, lonu, latu, lonv, latv, lonq, latq, tra_lat, tra_lon):
325
478
  """Translate grid so that the centre lies at the position (tra_lat, tra_lon)"""
326
479
 
327
- (lon3, lat3) = tra_sphere(lon2, lat2, tra_lat)
328
- (lonu, latu) = tra_sphere(lonu, latu, tra_lat)
329
- (lonv, latv) = tra_sphere(lonv, latv, tra_lat)
330
- (lone, late) = tra_sphere(lone, late, tra_lat)
480
+ (lon, lat) = _tra_sphere(lon, lat, tra_lat)
481
+ (lonu, latu) = _tra_sphere(lonu, latu, tra_lat)
482
+ (lonv, latv) = _tra_sphere(lonv, latv, tra_lat)
483
+ (lonq, latq) = _tra_sphere(lonq, latq, tra_lat)
331
484
 
332
- lon4 = lon3 + tra_lon * np.pi / 180
485
+ lon = lon + tra_lon * np.pi / 180
333
486
  lonu = lonu + tra_lon * np.pi / 180
334
487
  lonv = lonv + tra_lon * np.pi / 180
335
- lone = lone + tra_lon * np.pi / 180
336
- lon4[lon4 < -np.pi] = lon4[lon4 < -np.pi] + 2 * np.pi
488
+ lonq = lonq + tra_lon * np.pi / 180
489
+
490
+ lon[lon < -np.pi] = lon[lon < -np.pi] + 2 * np.pi
337
491
  lonu[lonu < -np.pi] = lonu[lonu < -np.pi] + 2 * np.pi
338
492
  lonv[lonv < -np.pi] = lonv[lonv < -np.pi] + 2 * np.pi
339
- lone[lone < -np.pi] = lone[lone < -np.pi] + 2 * np.pi
340
- lat4 = lat3
341
-
342
- return lon4, lat4, lonu, latu, lonv, latv, lone, late
493
+ lonq[lonq < -np.pi] = lonq[lonq < -np.pi] + 2 * np.pi
343
494
 
495
+ return lon, lat, lonu, latu, lonv, latv, lonq, latq
344
496
 
345
- def rot_sphere(lon1, lat1, rot):
346
497
 
347
- (n, m) = np.shape(lon1)
498
+ def _rot_sphere(lon, lat, rot):
499
+ (n, m) = np.shape(lon)
500
+ # convert rotation angle from degrees to radians
348
501
  rot = rot * np.pi / 180
349
502
 
350
- # translate into x,y,z
503
+ # translate into Cartesian coordinates x,y,z
351
504
  # conventions: (lon,lat) = (0,0) corresponds to (x,y,z) = ( 0,-r, 0)
352
505
  # (lon,lat) = (0,90) corresponds to (x,y,z) = ( 0, 0, r)
353
- x1 = np.sin(lon1) * np.cos(lat1)
354
- y1 = np.cos(lon1) * np.cos(lat1)
355
- z1 = np.sin(lat1)
506
+ x1 = np.sin(lon) * np.cos(lat)
507
+ y1 = np.cos(lon) * np.cos(lat)
508
+ z1 = np.sin(lat)
356
509
 
357
510
  # We will rotate these points around the small circle defined by
358
511
  # the intersection of the sphere and the plane that
@@ -379,38 +532,33 @@ def rot_sphere(lon1, lat1, rot):
379
532
  y2 = y1
380
533
  z2 = rp1 * np.sin(ap2)
381
534
 
382
- lon2 = np.pi / 2 * np.ones((n, m))
383
- lon2[abs(y2) > 1e-7] = np.arctan(
535
+ lon = np.pi / 2 * np.ones((n, m))
536
+ lon[abs(y2) > 1e-7] = np.arctan(
384
537
  np.abs(x2[np.abs(y2) > 1e-7] / y2[np.abs(y2) > 1e-7])
385
538
  )
386
- lon2[y2 < 0] = np.pi - lon2[y2 < 0]
387
- lon2[x2 < 0] = -lon2[x2 < 0]
539
+ lon[y2 < 0] = np.pi - lon[y2 < 0]
540
+ lon[x2 < 0] = -lon[x2 < 0]
388
541
 
389
542
  pr2 = np.sqrt(x2**2 + y2**2)
390
- lat2 = np.pi / 2 * np.ones((n, m))
391
- lat2[np.abs(pr2) > 1e-7] = np.arctan(
543
+ lat = np.pi / 2 * np.ones((n, m))
544
+ lat[np.abs(pr2) > 1e-7] = np.arctan(
392
545
  np.abs(z2[np.abs(pr2) > 1e-7] / pr2[np.abs(pr2) > 1e-7])
393
546
  )
394
- lat2[z2 < 0] = -lat2[z2 < 0]
395
-
396
- return (lon2, lat2)
547
+ lat[z2 < 0] = -lat[z2 < 0]
397
548
 
549
+ return (lon, lat)
398
550
 
399
- def tra_sphere(lon1, lat1, tra):
400
551
 
401
- # Rotate sphere around its y-axis
402
- # Part of easy grid
403
- # (c) 2008, Jeroen Molemaker, UCLA
404
-
405
- (n, m) = np.shape(lon1)
552
+ def _tra_sphere(lon, lat, tra):
553
+ (n, m) = np.shape(lon)
406
554
  tra = tra * np.pi / 180 # translation in latitude direction
407
555
 
408
556
  # translate into x,y,z
409
557
  # conventions: (lon,lat) = (0,0) corresponds to (x,y,z) = ( 0,-r, 0)
410
558
  # (lon,lat) = (0,90) corresponds to (x,y,z) = ( 0, 0, r)
411
- x1 = np.sin(lon1) * np.cos(lat1)
412
- y1 = np.cos(lon1) * np.cos(lat1)
413
- z1 = np.sin(lat1)
559
+ x1 = np.sin(lon) * np.cos(lat)
560
+ y1 = np.cos(lon) * np.cos(lat)
561
+ z1 = np.sin(lat)
414
562
 
415
563
  # We will rotate these points around the small circle defined by
416
564
  # the intersection of the sphere and the plane that
@@ -438,29 +586,29 @@ def tra_sphere(lon1, lat1, tra):
438
586
  z2 = rp1 * np.sin(ap2)
439
587
 
440
588
  ## transformation from (x,y,z) to (lat,lon)
441
- lon2 = np.pi / 2 * np.ones((n, m))
442
- lon2[np.abs(y2) > 1e-7] = np.arctan(
589
+ lon = np.pi / 2 * np.ones((n, m))
590
+ lon[np.abs(y2) > 1e-7] = np.arctan(
443
591
  np.abs(x2[np.abs(y2) > 1e-7] / y2[np.abs(y2) > 1e-7])
444
592
  )
445
- lon2[y2 < 0] = np.pi - lon2[y2 < 0]
446
- lon2[x2 < 0] = -lon2[x2 < 0]
593
+ lon[y2 < 0] = np.pi - lon[y2 < 0]
594
+ lon[x2 < 0] = -lon[x2 < 0]
447
595
 
448
596
  pr2 = np.sqrt(x2**2 + y2**2)
449
- lat2 = np.pi / (2 * np.ones((n, m)))
450
- lat2[np.abs(pr2) > 1e-7] = np.arctan(
597
+ lat = np.pi / (2 * np.ones((n, m)))
598
+ lat[np.abs(pr2) > 1e-7] = np.arctan(
451
599
  np.abs(z2[np.abs(pr2) > 1e-7] / pr2[np.abs(pr2) > 1e-7])
452
600
  )
453
- lat2[z2 < 0] = -lat2[z2 < 0]
601
+ lat[z2 < 0] = -lat[z2 < 0]
454
602
 
455
- return (lon2, lat2)
603
+ return (lon, lat)
456
604
 
457
605
 
458
- def _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv):
606
+ def _compute_coordinate_metrics(lon, lonu, latu, lonv, latv):
459
607
  """Compute the curvilinear coordinate metrics pn and pm, defined as 1/grid spacing"""
460
608
 
461
609
  # pm = 1/dx
462
610
  pmu = gc_dist(lonu[:, :-1], latu[:, :-1], lonu[:, 1:], latu[:, 1:])
463
- pm = 0 * lon4
611
+ pm = 0 * lon
464
612
  pm[:, 1:-1] = pmu
465
613
  pm[:, 0] = pm[:, 1]
466
614
  pm[:, -1] = pm[:, -2]
@@ -468,7 +616,7 @@ def _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv):
468
616
 
469
617
  # pn = 1/dy
470
618
  pnv = gc_dist(lonv[:-1, :], latv[:-1, :], lonv[1:, :], latv[1:, :])
471
- pn = 0 * lon4
619
+ pn = 0 * lon
472
620
  pn[1:-1, :] = pnv
473
621
  pn[0, :] = pn[1, :]
474
622
  pn[-1, :] = pn[-2, :]
@@ -478,7 +626,6 @@ def _compute_coordinate_metrics(lon4, lonu, latu, lonv, latv):
478
626
 
479
627
 
480
628
  def gc_dist(lon1, lat1, lon2, lat2):
481
-
482
629
  # Distance between 2 points along a great circle
483
630
  # lat and lon in radians!!
484
631
  # 2008, Jeroen Molemaker, UCLA
@@ -497,7 +644,7 @@ def gc_dist(lon1, lat1, lon2, lat2):
497
644
  return dis
498
645
 
499
646
 
500
- def _compute_angle(lon4, lonu, latu, lone):
647
+ def _compute_angle(lon, lonu, latu, lonq):
501
648
  """Compute angles of local grid positive x-axis relative to east"""
502
649
 
503
650
  dellat = latu[:, 1:] - latu[:, :-1]
@@ -506,7 +653,7 @@ def _compute_angle(lon4, lonu, latu, lone):
506
653
  dellon[dellon < -np.pi] = dellon[dellon < -np.pi] + 2 * np.pi
507
654
  dellon = dellon * np.cos(0.5 * (latu[:, 1:] + latu[:, :-1]))
508
655
 
509
- ang = copy.copy(lon4)
656
+ ang = copy.copy(lon)
510
657
  ang_s = np.arctan(dellat / (dellon + 1e-16))
511
658
  ang_s[(dellon < 0) & (dellat < 0)] = ang_s[(dellon < 0) & (dellat < 0)] - np.pi
512
659
  ang_s[(dellon < 0) & (dellat >= 0)] = ang_s[(dellon < 0) & (dellat >= 0)] + np.pi
@@ -517,36 +664,29 @@ def _compute_angle(lon4, lonu, latu, lone):
517
664
  ang[:, 0] = ang[:, 1]
518
665
  ang[:, -1] = ang[:, -2]
519
666
 
520
- lon4[lon4 < 0] = lon4[lon4 < 0] + 2 * np.pi
521
- lone[lone < 0] = lone[lone < 0] + 2 * np.pi
667
+ lon[lon < 0] = lon[lon < 0] + 2 * np.pi
668
+ lonq[lonq < 0] = lonq[lonq < 0] + 2 * np.pi
522
669
 
523
670
  return ang
524
671
 
525
672
 
526
673
  def _create_grid_ds(
527
- nx,
528
- ny,
529
674
  lon,
530
675
  lat,
531
676
  pm,
532
677
  pn,
533
678
  angle,
534
- size_x,
535
- size_y,
536
679
  rot,
537
680
  center_lon,
538
681
  center_lat,
539
- lone,
540
- late,
541
682
  ):
542
-
543
- # Coriolis frequency
544
- f0 = 4 * np.pi * np.sin(lat) / (24 * 3600)
545
-
546
- # Create empty xarray.Dataset object to store variables in
547
- ds = xr.Dataset()
548
-
549
- # TODO some of these variables are defined but never written to in Easy Grid
683
+ # Create xarray.Dataset object with lat_rho and lon_rho as coordinates
684
+ ds = xr.Dataset(
685
+ coords={
686
+ "lat_rho": (("eta_rho", "xi_rho"), lat * 180 / np.pi),
687
+ "lon_rho": (("eta_rho", "xi_rho"), lon * 180 / np.pi),
688
+ }
689
+ )
550
690
 
551
691
  ds["angle"] = xr.Variable(
552
692
  data=angle,
@@ -554,17 +694,16 @@ def _create_grid_ds(
554
694
  attrs={"long_name": "Angle between xi axis and east", "units": "radians"},
555
695
  )
556
696
 
557
- # ds['h'] = ...
558
- # TODO hraw comes from topography
559
- # ds['hraw'] = xr.Variable(data=hraw, dims=['eta_rho', 'xi_rho'])
697
+ # Coriolis frequency
698
+ f0 = 4 * np.pi * np.sin(lat) / (24 * 3600)
560
699
 
561
- ds["f0"] = xr.Variable(
700
+ ds["f"] = xr.Variable(
562
701
  data=f0,
563
702
  dims=["eta_rho", "xi_rho"],
564
703
  attrs={"long_name": "Coriolis parameter at rho-points", "units": "second-1"},
565
704
  )
566
- ds["pn"] = xr.Variable(
567
- data=pn,
705
+ ds["pm"] = xr.Variable(
706
+ data=pm,
568
707
  dims=["eta_rho", "xi_rho"],
569
708
  attrs={
570
709
  "long_name": "Curvilinear coordinate metric in xi-direction",
@@ -572,7 +711,7 @@ def _create_grid_ds(
572
711
  },
573
712
  )
574
713
  ds["pn"] = xr.Variable(
575
- data=pm,
714
+ data=pn,
576
715
  dims=["eta_rho", "xi_rho"],
577
716
  attrs={
578
717
  "long_name": "Curvilinear coordinate metric in eta-direction",
@@ -585,66 +724,86 @@ def _create_grid_ds(
585
724
  dims=["eta_rho", "xi_rho"],
586
725
  attrs={"long_name": "longitude of rho-points", "units": "degrees East"},
587
726
  )
727
+
588
728
  ds["lat_rho"] = xr.Variable(
589
729
  data=lat * 180 / np.pi,
590
730
  dims=["eta_rho", "xi_rho"],
591
731
  attrs={"long_name": "latitude of rho-points", "units": "degrees North"},
592
732
  )
593
733
 
594
- ds["spherical"] = xr.Variable(
595
- data=["T"],
596
- dims=["one"],
597
- attrs={
598
- "long_name": "Grid type logical switch",
599
- "option_T": "spherical",
600
- },
601
- )
602
-
603
- # TODO this mask is obtained from hraw
604
- # ds['mask_rho'] = xr.Variable(data=lat * 180 / np.pi, dims=['eta_rho', 'xi_rho'], attrs={'long_name': "latitude of rho-points", 'units': "degrees North"})
734
+ ds["tra_lon"] = center_lon
735
+ ds["tra_lon"].attrs["long_name"] = "Longitudinal translation of base grid"
736
+ ds["tra_lon"].attrs["units"] = "degrees East"
605
737
 
606
- # TODO this 'one' dimension is completely unneccessary as netCDF can store scalars
607
- ds["tra_lon"] = xr.Variable(
608
- data=[center_lon],
609
- dims=["one"],
610
- attrs={
611
- "long_name": "Longitudinal translation of base grid",
612
- "units": "degrees East",
613
- },
614
- )
615
- ds["tra_lat"] = xr.Variable(
616
- data=[center_lat],
617
- dims=["one"],
618
- attrs={
619
- "long_name": "Latitudinal translation of base grid",
620
- "units": "degrees North",
621
- },
622
- )
623
- ds["rotate"] = xr.Variable(
624
- data=[rot],
625
- dims=["one"],
626
- attrs={"long_name": "Rotation of base grid", "units": "degrees"},
627
- )
738
+ ds["tra_lat"] = center_lat
739
+ ds["tra_lat"].attrs["long_name"] = "Latitudinal translation of base grid"
740
+ ds["tra_lat"].attrs["units"] = "degrees North"
628
741
 
629
- # TODO this is never written to
630
- # ds['xy_flip']
742
+ ds["rotate"] = rot
743
+ ds["rotate"].attrs["long_name"] = "Rotation of base grid"
744
+ ds["rotate"].attrs["units"] = "degrees"
631
745
 
632
746
  return ds
633
747
 
634
748
 
635
- def _add_global_metadata(ds, nx, ny, size_x, size_y, center_lon, center_lat, rot):
636
-
637
- ds.attrs["Title"] = (
638
- "ROMS grid. Settings:"
639
- f" nx: {nx} ny: {ny} "
640
- f" xsize: {size_x / 1e3} ysize: {size_y / 1e3}"
641
- f" rotate: {rot} Lon: {center_lon} Lat: {center_lat}"
642
- )
643
- ds.attrs["Date"] = date.today()
749
+ def _add_global_metadata(ds, size_x, size_y):
644
750
  ds.attrs["Type"] = "ROMS grid produced by roms-tools"
751
+ ds.attrs["size_x"] = size_x
752
+ ds.attrs["size_y"] = size_y
645
753
 
646
754
  return ds
647
755
 
648
756
 
649
- def _make_topography(ds):
650
- ...
757
+ def _f2c(f):
758
+ """
759
+ Coarsen input xarray DataArray f in both x- and y-direction.
760
+
761
+ Parameters
762
+ ----------
763
+ f : xarray.DataArray
764
+ Input DataArray with dimensions (nxp, nyp).
765
+
766
+ Returns
767
+ -------
768
+ fc : xarray.DataArray
769
+ Output DataArray with modified dimensions and values.
770
+ """
771
+
772
+ fc = _f2c_xdir(f)
773
+ fc = fc.transpose()
774
+ fc = _f2c_xdir(fc)
775
+ fc = fc.transpose()
776
+ fc = fc.rename({"eta_rho": "eta_coarse", "xi_rho": "xi_coarse"})
777
+
778
+ return fc
779
+
780
+
781
+ def _f2c_xdir(f):
782
+ """
783
+ Coarsen input xarray DataArray f in x-direction.
784
+
785
+ Parameters
786
+ ----------
787
+ f : xarray.DataArray
788
+ Input DataArray with dimensions (nxp, nyp).
789
+
790
+ Returns
791
+ -------
792
+ fc : xarray.DataArray
793
+ Output DataArray with modified dimensions and values.
794
+ """
795
+ nxp, nyp = f.shape
796
+ nxcp = (nxp - 2) // 2 + 2
797
+
798
+ fc = xr.DataArray(np.zeros((nxcp, nyp)), dims=f.dims)
799
+
800
+ # Calculate the interior values
801
+ fc[1:-1, :] = 0.5 * (f[1:-2:2, :] + f[2:-1:2, :])
802
+
803
+ # Calculate the first row
804
+ fc[0, :] = f[0, :] + 0.5 * (f[0, :] - f[1, :])
805
+
806
+ # Calculate the last row
807
+ fc[-1, :] = f[-1, :] + 0.5 * (f[-1, :] - f[-2, :])
808
+
809
+ return fc