ladim-2.0.5-py3-none-any.whl → ladim-2.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ladim/__init__.py CHANGED
@@ -1,3 +1,3 @@
- __version__ = '2.0.5'
+ __version__ = '2.0.6'

  from .main import main, run
ladim/forcing.py CHANGED
@@ -2,19 +2,15 @@ from .model import Model, Module


  class Forcing(Module):
- def __init__(self, model: Model):
- super().__init__(model)
-
  def velocity(self, X, Y, Z, tstep=0.0):
  raise NotImplementedError


  class RomsForcing(Forcing):
- def __init__(self, model: Model, file, variables=None, **conf):
+ def __init__(self, file, variables=None, **conf):
  """
  Forcing module which uses output data from the ROMS ocean model

- :param model: Parent model
  :param file: Glob pattern for the input files
  :param variables: A mapping of variable names to interpolation
  specifications. Each interpolaction specification consists of 0-4
@@ -34,14 +30,12 @@ class RomsForcing(Forcing):

  :param conf: Legacy config dict
  """
- super().__init__(model)
-
  # Apply default interpolation configs
  variables = variables or dict()
  default_vars = dict(u="xt", v="yt", w="zt", temp="xyzt", salt="xyzt")
  self.variables = {**default_vars, **variables}

- grid_ref = GridReference(model)
+ grid_ref = GridReference()
  legacy_conf = dict(
  gridforce=dict(
  input_file=file,
@@ -67,17 +61,19 @@ class RomsForcing(Forcing):
  # self.U = self.forcing.U
  # self.V = self.forcing.V

- def update(self):
- elapsed = self.model.solver.time - self.model.solver.start
- t = elapsed // self.model.solver.step
+ def update(self, model: Model):
+ elapsed = model.solver.time - model.solver.start
+ t = elapsed // model.solver.step

+ # noinspection PyProtectedMember
+ self.forcing._grid.modules = model
  self.forcing.update(t)

  # Update state variables by sampling the field
- x, y, z = self.model.state['X'], self.model.state['Y'], self.model.state['Z']
+ x, y, z = model.state['X'], model.state['Y'], model.state['Z']
  for v in self.variables:
- if v in self.model.state:
- self.model.state[v] = self.field(x, y, z, v)
+ if v in model.state:
+ model.state[v] = self.field(x, y, z, v)

  def velocity(self, X, Y, Z, tstep=0.0):
  return self.forcing.velocity(X, Y, Z, tstep=tstep)
@@ -90,8 +86,8 @@ class RomsForcing(Forcing):


  class GridReference:
- def __init__(self, modules: Model):
- self.modules = modules
+ def __init__(self):
+ self.modules = None

  def __getattr__(self, item):
  return getattr(self.modules.grid.grid, item)
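
Note on the change above: in 2.0.6 the parent Model is no longer passed to module constructors; RomsForcing.update() now receives the model on every call, and GridReference is bound to the model lazily inside update(). The following is a minimal, self-contained sketch of that calling convention; the Module/ExampleForcing classes and the SimpleNamespace stand-in for the model are illustrative assumptions, not the actual ladim implementation.

from types import SimpleNamespace


class Module:
    """Base-class sketch: modules receive the model in update(), not in __init__."""
    def update(self, model):
        raise NotImplementedError


class ExampleForcing(Module):
    def __init__(self, scale=1.0):       # no Model argument, as in 2.0.6
        self.scale = scale

    def update(self, model):              # the model is injected per time step
        elapsed = model.solver.time - model.solver.start
        t = elapsed // model.solver.step
        model.state['T'] = self.scale * t   # write the sampled value into the state


model = SimpleNamespace(
    solver=SimpleNamespace(start=0, time=3600, step=600),
    state={'T': None},
)
ExampleForcing().update(model)
assert model.state['T'] == 6.0
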
ladim/grid.py CHANGED
@@ -1,9 +1,15 @@
- from .model import Model, Module
+ from .model import Module
+ import numpy as np
+ from typing import Sequence
+ from scipy.ndimage import map_coordinates


  class Grid(Module):
- def __init__(self, model: Model):
- super().__init__(model)
+ """
+ The grid class represents the coordinate system used for particle tracking.
+ It contains methods for converting between global coordinates (latitude,
+ longitude, depth and posix time) and internal coordinates.
+ """

  def ingrid(self, X, Y):
  raise NotImplementedError
@@ -20,18 +26,158 @@ class Grid(Module):
  def xy2ll(self, x, y):
  raise NotImplementedError

+ # --- MODERN METHODS ---
+
+ def dx(self, x: Sequence, y: Sequence) -> np.ndarray:
+ """
+ Metric scale factor in the X direction
+
+ The metric scale factor is defined such that if one moves
+ a small increment delta along the axis, then the distance
+ (in meters) equals scale_factor * delta.
+
+ :param x: X positions
+ :param y: Y positions
+ :return: Metric scale factor [in meters per grid unit]
+ """
+ raise NotImplementedError
+
+ def dy(self, x: Sequence, y: Sequence) -> np.ndarray:
+ """
+ Metric scale factor in the Y direction
+
+ The metric scale factor is defined such that if one moves
+ a small increment delta along the axis, then the distance
+ (in meters) equals scale_factor * delta.
+
+ :param x: X positions
+ :param y: Y positions
+ :return: Metric scale factor [in meters per grid unit]
+ """
+ raise NotImplementedError
+
+ def from_bearing(
+ self, x: Sequence, y: Sequence, b: Sequence
+ ) -> np.ndarray:
+ """
+ Azimutal angles from compass bearings.
+
+ A compass bearing (in degrees) is defined such that 0 is north, 90
+ is east, 180 is south and 270 is west. An azimutal vector angle is
+ defined such that 0 is pointing along the X axis, 90 is pointing
+ along the Y axis, 180 is pointing opposite the X axis and 270 is
+ pointing opposite the Y axis.
+
+ This function computes a set of azimutal vector angles from a set
+ of compass bearings and horizontal positions.
+
+ :param x: X positions
+ :param y: Y positions
+ :param b: Compass bearings [degrees]
+ :return: Azimutal vector angles [degrees]
+ """
+ raise NotImplementedError
+
+ def to_bearing(
+ self, x: Sequence, y: Sequence, az: Sequence
+ ) -> np.ndarray:
+ """
+ Azimutal angles from compass bearings.
+
+ A compass bearing (in degrees) is defined such that 0 is north, 90
+ is east, 180 is south and 270 is west. An azimutal vector angle is
+ defined such that 0 is pointing along the X axis, 90 is pointing
+ along the Y axis, 180 is pointing opposite the X axis and 270 is
+ pointing opposite the Y axis.
+
+ This function computes a set of compass bearings from a set
+ of azimutal vector angles and horizontal positions.
+
+ :param x: X positions
+ :param y: Y positions
+ :param az: Azimutal vector angles [degrees]
+ :return: Compass bearings [degrees]
+ """
+ raise NotImplementedError
+
+ def from_latlon(
+ self, lat: Sequence, lon: Sequence,
+ ) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Horizontal coordinates from latitude and longitude
+
+ :param lat: Latitude [degrees north]
+ :param lon: Longitude [degrees east]
+ :return: A tuple (x, y) of horizontal coordinates
+ """
+ raise NotImplementedError
+
+ def to_latlon(
+ self, x: Sequence, y: Sequence,
+ ) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Latitude and longitude from horizontal coordinates
+
+ :param x: X positions
+ :param y: Y positions
+ :return: A tuple (lat, lon) of latitude [degrees north] and longitude [degrees east]
+ """
+ raise NotImplementedError
+
+ def from_depth(
+ self, x: Sequence, y: Sequence, z: Sequence
+ ) -> np.ndarray:
+ """
+ Vertical coordinates from depth and horizontal coordinates.
+
+ :param x: X positions
+ :param y: Y positions
+ :param z: Depth below surface [m, positive downwards]
+ :return: Vertical coordinates
+ """
+ raise NotImplementedError
+
+ def to_depth(
+ self, x: Sequence, y: Sequence, s: Sequence
+ ) -> np.ndarray:
+ """
+ Depth from horizontal and vertical coordinates.
+
+ :param x: X positions
+ :param y: Y positions
+ :param s: Vertical coordinates
+ :return: Depth below surface [m, positive downwards]
+ """
+ raise NotImplementedError
+
+ def from_epoch(self, p: Sequence) -> np.ndarray:
+ """
+ Time coordinates from posix time
+
+ :param p: Posix time [seconds since 1970-01-01]
+ :return: Time coordinates
+ """
+ raise NotImplementedError
+
+ def to_epoch(self, t: Sequence) -> np.ndarray:
+ """
+ Posix time from time coordinates
+
+ :param t: Time coordinates
+ :return: Posix time [seconds since 1970-01-01]
+ """
+ raise NotImplementedError
+

  class RomsGrid(Grid):
  def __init__(
  self,
- model: Model,
  file: str,
  start_time=None,
  subgrid=None,
  legacy_module='ladim.gridforce.ROMS.Grid',
  **_,
  ):
- super().__init__(model)

  legacy_conf = dict(
  gridforce=dict(
@@ -87,3 +233,402 @@ class RomsGrid(Grid):

  def xy2ll(self, X, Y):
  return self.grid.xy2ll(X, Y)
+
+
+ class ArrayGrid(Grid):
+ def __init__(
+ self,
+ lat: np.ndarray | tuple[tuple] = ((), ),
+ lon: np.ndarray | tuple[tuple] = ((), ),
+ depth: np.ndarray | tuple[tuple[tuple]] = (((), ), ),
+ time: np.ndarray | tuple = (),
+ mask: np.ndarray | tuple[tuple] = ((), ),
+ ):
+ """
+ Define an array grid
+
+ The number of lattice points in the T (time), Z (depth), Y and X
+ dimensions are NT, NZ, NY and NX, respectively. It is assumed that
+
+ - The lat/lon coordinates are independent of the T and Z dimensions
+ - The depth is independent of the T dimension
+ - The time is independent of the X, Y and Z timensions
+ - Time values must be increasing
+ - Depth values must be decreasing with Z
+
+ :param lat: Latitude coordinates [degrees, NY * NX array]
+ :param lon: Longitude coordinates [degrees, NY * NX array]
+ :param depth: Depth below surface [meters, positive downwards, NZ * NY * NX array]
+ :param time: Time since 1970-01-01 [seconds, NT array]
+ :param mask: Zero at land positions (default all ones) [NY * NX array]
+ """
+ self.lat = np.asarray(lat, dtype='f8')
+ self.lon = np.asarray(lon, dtype='f8')
+ self.depth = np.asarray(depth, dtype='f4')
+ self.time = np.asarray(time, dtype='datetime64[s]').astype('int64')
+
+ if (not mask) or np.size(mask) == 0:
+ self.mask = np.ones(self.depth.shape[-2:], dtype='i2')
+ else:
+ self.mask = np.asarray(mask, dtype='i2')
+
+ self._cache_dict = dict()
+
+ if np.any(np.diff(self.time) <= 0):
+ raise ValueError('Time values must be increasing')
+ if np.any(self.depth[0] < self.depth[-1]):
+ raise ValueError('Depth values must be decreasing with Z')
+
+ def from_epoch(self, p: Sequence) -> np.ndarray:
+ return np.interp(x=p, xp=self.time, fp=np.arange(len(self.time)))
+
+ def to_epoch(self, t: Sequence) -> np.ndarray:
+ return map_coordinates(self.time, (t, ), order=1, mode='nearest')
+
+ def to_latlon(
+ self, x: Sequence, y: Sequence,
+ ) -> tuple[np.ndarray, np.ndarray]:
+ lat = map_coordinates(self.lat, (y, x), order=1, mode='nearest')
+ lon = map_coordinates(self.lon, (y, x), order=1, mode='nearest')
+ return lat, lon
+
+ def from_latlon(
+ self, lat: Sequence, lon: Sequence,
+ ) -> tuple[np.ndarray, np.ndarray]:
+ y, x = bilin_inv(f=lat, g=lon, F=self.lat, G=self.lon)
+ return x, y
+
+ def to_depth(
+ self, x: Sequence, y: Sequence, s: Sequence
+ ) -> np.ndarray:
+ mask = map_coordinates(self.mask, (y, x), order=0, mode='nearest')
+ depth = map_coordinates(self.depth, (s, y, x), order=1, mode='nearest')
+ depth[mask == 0] = 0
+ return depth
+
+ def from_depth(
+ self, x: Sequence, y: Sequence, z: Sequence
+ ) -> np.ndarray:
+ depths = bilinear_interp(self.depth, y, x)
+ idx, frac = array_lookup(
+ arr=-depths,
+ values=-np.asarray(z),
+ return_frac=True,
+ )
+
+ s = idx + frac
+ mask = map_coordinates(self.mask, (y, x), order=0, mode='nearest')
+ s[mask == 0] = 0
+ return s
+
+ def compute(self, key):
+ """
+ Cached computation of key variables
+
+ :param key: The variable to compute
+ :return: The computed variables
+ """
+ if key in self._cache_dict:
+ return self._cache_dict[key]
+
+ if key in ['dx', 'dy']:
+ dx, dy = compute_dx_dy(lat=self.lat, lon=self.lon)
+ self._cache_dict['dx'] = dx
+ self._cache_dict['dy'] = dy
+
+ elif key in ['latdiff_x', 'latdiff_y', 'londiff_x', 'londiff_y']:
+ lax = (self.lat[:, 1:] - self.lat[:, :-1]) * (np.pi / 180)
+ lox = (self.lon[:, 1:] - self.lon[:, :-1]) * (np.pi / 180)
+ lay = (self.lat[1:, :] - self.lat[:-1, :]) * (np.pi / 180)
+ loy = (self.lon[1:, :] - self.lon[:-1, :]) * (np.pi / 180)
+ self._cache_dict['latdiff_x'] = lax
+ self._cache_dict['londiff_x'] = lox
+ self._cache_dict['latdiff_y'] = lay
+ self._cache_dict['londiff_y'] = loy
+
+ return self._cache_dict[key]
+
+ def dx(self, x: Sequence, y: Sequence) -> np.ndarray:
+ x = np.asarray(x)
+ y = np.asarray(y)
+ dx = self.compute('dx')
+ coords = (y, x - 0.5) # Convert from 'rho' to 'u' coordinates
+ return map_coordinates(dx, coords, order=1, mode='nearest')
+
+ def dy(self, x: Sequence, y: Sequence) -> np.ndarray:
+ x = np.asarray(x)
+ y = np.asarray(y)
+ dy = self.compute('dy')
+ coords = (y - 0.5, x) # Convert from 'rho' to 'v' coordinates
+ return map_coordinates(dy, coords, order=1, mode='nearest')
+
+ def _latlondiff(self, x: Sequence, y: Sequence):
+ """
+ Compute latitude and longitude unit difference at selected points
+
+ Returns a tuple latdiff_x, latdiff_y, londiff_x, londiff_y. Together,
+ these variables tell how much the latitude and longitude increases when
+ moving by one grid cell in either the X or Y direction.
+
+ :param x: X coordinates of starting points
+ :param y: Y coordinates of starting points
+ :return: A tuple latdiff_x, latdiff_y, londiff_x, londiff_y
+ """
+ x = np.asarray(x)
+ y = np.asarray(y)
+
+ latdiff_xdir_grid = self.compute('latdiff_x')
+ londiff_xdir_grid = self.compute('londiff_x')
+ latdiff_ydir_grid = self.compute('latdiff_y')
+ londiff_ydir_grid = self.compute('londiff_y')
+
+ crd_x = (y, x - 0.5) # Convert from 'rho' to 'u' coordinates
+ crd_y = (y - 0.5, x) # Convert from 'rho' to 'v' coordinates
+ latdiff_xdir = map_coordinates(latdiff_xdir_grid, crd_x, order=1, mode='nearest')
+ londiff_xdir = map_coordinates(londiff_xdir_grid, crd_x, order=1, mode='nearest')
+ latdiff_ydir = map_coordinates(latdiff_ydir_grid, crd_y, order=1, mode='nearest')
+ londiff_ydir = map_coordinates(londiff_ydir_grid, crd_y, order=1, mode='nearest')
+
+ return latdiff_xdir, latdiff_ydir, londiff_xdir, londiff_ydir
+
+ def to_bearing(
+ self, x: Sequence, y: Sequence, az: Sequence
+ ) -> np.ndarray:
+ # Compute unit lat/lon difference in the x and y directions
+ latdiff_x, latdiff_y, londiff_x, londiff_y = self._latlondiff(x, y)
+
+ # Define directional vector 'p' which is defined on the x/y grid
+ az_radians = np.asarray(az) * (np.pi / 180)
+ p_x = np.cos(az_radians)
+ p_y = np.sin(az_radians)
+
+ # Define new vector 'q' which is defined on the lon/lat grid
+ # and has the same direction as 'p'
+ q_lat = p_x * latdiff_x + p_y * latdiff_y
+ q_lon = p_x * londiff_x + p_y * londiff_y
+
+ # Compute bearing
+ bearing_radians = np.atan2(q_lon, q_lat)
+ bearing = (bearing_radians * (180 / np.pi)) % 360
+ return bearing
+
+ def from_bearing(
+ self, x: Sequence, y: Sequence, b: Sequence
+ ) -> np.ndarray:
+ # Compute unit lat/lon difference in the x and y directions
+ latdiff_x, latdiff_y, londiff_x, londiff_y = self._latlondiff(x, y)
+
+ # Define directional vector 'q' which is defined on the lat/lon grid
+ bearing_radians = np.asarray(b) * (np.pi / 180)
+ q_lat = np.cos(bearing_radians)
+ q_lon = np.sin(bearing_radians)
+
+ # Define new vector 'p' which is defined on the x/y grid
+ # and has the same direction as 'q'
+ p_x = q_lat * londiff_y - q_lon * latdiff_y
+ p_y = -q_lat * londiff_x + q_lon * latdiff_x
+
+ # Compute azimuth
+ az_radians = np.atan2(p_y, p_x)
+ az = (az_radians * (180 / np.pi)) % 360
+ return az
+
+
+ def bilin_inv(f, g, F, G, maxiter=7, tol=1.0e-7) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Inverse bilinear interpolation
+
+ ``f, g`` should be scalars or arrays of same shape
+
+ ``F, G`` should be 2D arrays of the same shape
+
+ :param f: Desired f value
+ :param g: Desired g value
+ :param F: Tabulated f values
+ :param G: Tabulated g values
+ :param maxiter: Maximum number of Newton iterations
+ :param tol: Maximum residual value
+ :return: A tuple ``(x, y)`` such that ``F[x, y] = f`` and ``G[x, y] = g``, when
+ linearly interpolated
+ """
+
+ imax, jmax = np.array(F.shape) - 1
+
+ f = np.asarray(f)
+ g = np.asarray(g)
+
+ # initial guess
+ x = np.zeros_like(f) + 0.5 * imax
+ y = np.zeros_like(f) + 0.5 * jmax
+
+ for t in range(maxiter):
+ i = np.minimum(imax - 1, x.astype("i4"))
+ j = np.minimum(jmax - 1, y.astype("i4"))
+
+ p, q = x - i, y - j
+
+ # Shorthands
+ F00 = F[i, j]
+ F01 = F[i, j+1]
+ F10 = F[i+1, j]
+ F11 = F[i+1, j+1]
+ G00 = G[i, j]
+ G01 = G[i, j+1]
+ G10 = G[i+1, j]
+ G11 = G[i+1, j+1]
+
+ # Bilinear estimate of F[x,y] and G[x,y]
+ Fs = (
+ (1 - p) * (1 - q) * F00
+ + p * (1 - q) * F10
+ + (1 - p) * q * F01
+ + p * q * F11
+ )
+ Gs = (
+ (1 - p) * (1 - q) * G00
+ + p * (1 - q) * G10
+ + (1 - p) * q * G01
+ + p * q * G11
+ )
+
+ H = (Fs - f) ** 2 + (Gs - g) ** 2
+
+ if np.all(H < tol**2):
+ break
+
+ # Estimate Jacobi matrix
+ Fx = (1 - q) * (F10 - F00) + q * (F11 - F01)
+ Fy = (1 - p) * (F01 - F00) + p * (F11 - F10)
+ Gx = (1 - q) * (G10 - G00) + q * (G11 - G01)
+ Gy = (1 - p) * (G01 - G00) + p * (G11 - G10)
+
+ # Newton-Raphson step
+ # Jinv = np.linalg.inv([[Fx, Fy], [Gx, Gy]])
+ # incr = - np.dot(Jinv, [Fs-f, Gs-g])
+ # x = x + incr[0], y = y + incr[1]
+ det = Fx * Gy - Fy * Gx
+ x -= (Gy * (Fs - f) - Fy * (Gs - g)) / det
+ y -= (-Gx * (Fs - f) + Fx * (Gs - g)) / det
+
+ x = np.maximum(0, np.minimum(imax, x))
+ y = np.maximum(0, np.minimum(jmax, y))
+
+ return x, y
+
+
+ def array_lookup(arr, values, return_frac=False):
+ """
+ Find indices of a set of values
+
+ The lookup table "arr" has dimensions N * M, and should be sorted
+ along the M axis (i.e., arr[i, j] <= arr[i, j + 1] for all i, j)
+
+ The search value array "values" should have dimensions N. The values
+ are clipped by the minimum and maximum values given by "arr".
+
+ The function returns and index array "idx" such that arr[i, idx[i]] <=
+ values[i] < arr[i, idx[i] + 1] for all i. All values in "idx" are
+ values between 0 and M - 2.
+
+ If the parameter "return_frac" is set to True, the function returns an
+ additional array "frac" with values in the range [0, 1] such that
+ values == arr[i, idx[i]] * (1 - frac[i]) + arr[i, idx[i + 1]] * frac[i]
+ for all i.
+
+ :param arr: Lookup table, shape N * M
+ :param values: Values to search for, shape N
+ :param return_frac: True if interpolation index "frac" should be returned
+ :return: A tuple ("idx", "frac"), or just "idx" if return_frac is set to False
+ """
+
+ arr = np.asarray(arr)
+ values = np.asarray(values)
+ n, m = arr.shape
+
+ assert (n, ) == values.shape
+ assert np.all(arr[:, 0] <= arr[:, -1])
+
+ idx_raw = np.sum(arr.T <= values, axis=0) - 1
+ idx = np.maximum(0, np.minimum(idx_raw, m - 2))
+
+ if not return_frac:
+ return idx
+
+ i = np.arange(n)
+ values_0 = arr[i, idx]
+ values_1 = arr[i, idx + 1]
+ frac_raw = (values - values_0) / (values_1 - values_0)
+ frac = np.maximum(0, np.minimum(frac_raw, 1))
+ return idx, frac
+
+
+ def bilinear_interp(arr: np.ndarray, y: Sequence, x: Sequence):
+ """
+ Bilinear interpolation of a multi-dimensional array
+
+ The function interpolates the input array in the second last and last
+ dimensions, but leaves the first dimensions unchanged.
+
+ :param arr: Input array
+ :param y: Fractional coordinates of the second last dimension
+ :param x: Fractional coordinates of the last dimension
+ :return:
+ """
+ nx = arr.shape[-1]
+ ny = arr.shape[-2]
+
+ x = np.minimum(nx - 1, np.maximum(0, x))
+ y = np.minimum(ny - 1, np.maximum(0, y))
+
+ x0 = np.minimum(nx - 2, np.int32(x))
+ y0 = np.minimum(ny - 2, np.int32(y))
+
+ xf = x - x0
+ yf = y - y0
+
+ z00 = arr[..., y0, x0]
+ z01 = arr[..., y0, x0 + 1]
+ z10 = arr[..., y0 + 1, x0]
+ z11 = arr[..., y0 + 1, x0 + 1]
+
+ z = (
+ z00 * (1 - xf) * (1 - yf)
+ + z01 * xf * (1 - yf)
+ + z10 * (1 - xf) * yf
+ + z11 * xf * yf
+ )
+ return z.T
+
+
+ def compute_dx_dy(lat, lon):
+ """
+ Compute scale factors and bearings from grid
+
+ The grid is assumed to be a structured grid with lat/lon coordinates
+ at every grid point. The function computes two variables:
+
+ dx: The distance (in meters) when moving one grid cell along the X axis
+ dy: The distance (in meters) when moving one grid cell along the Y axis
+
+ The shape of the returned arrays will be one less than the input arrays
+ in the dimension where the differential is computed.
+
+ :param lat: Latitude (in degrees) of grid points
+ :param lon: Longitude (in degrees) of grid points
+ :return: A tuple (dx, dy)
+ """
+ import pyproj
+
+ geod = pyproj.Geod(ellps='WGS84')
+
+ _, _, dist_x = geod.inv(
+ lons1=lon[:, :-1], lats1=lat[:, :-1],
+ lons2=lon[:, 1:], lats2=lat[:, 1:],
+ )
+
+ _, _, dist_y = geod.inv(
+ lons1=lon[:-1, :], lats1=lat[:-1, :],
+ lons2=lon[1:, :], lats2=lat[1:, :],
+ )
+
+ return dist_x, dist_y
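
The new ArrayGrid class and its helper functions above are self-contained (numpy and scipy only; pyproj is needed only for the dx/dy scale factors). Below is a small usage sketch with made-up grid values for illustration; the ArrayGrid API itself is taken from the diff above.

import numpy as np
from ladim.grid import ArrayGrid

# A tiny NZ x NY x NX = 2 x 3 x 4 grid: regular 1-degree lat/lon spacing and
# two depth levels (100 m and 0 m, decreasing with Z as the class requires)
lon, lat = np.meshgrid(np.arange(4.0), 60.0 + np.arange(3.0))
depth = np.stack([np.full((3, 4), 100.0), np.zeros((3, 4))])
time = np.array([0, 3600, 7200])                # posix seconds

grid = ArrayGrid(lat=lat, lon=lon, depth=depth, time=time)

x, y = grid.from_latlon(lat=[60.5], lon=[1.5])  # fractional grid coords: x ~ 1.5, y ~ 0.5
lat2, lon2 = grid.to_latlon(x, y)               # round trip back to lat/lon
s = grid.from_depth(x, y, z=[25.0])             # 25 m depth -> s ~ 0.75 between the two levels
t = grid.from_epoch([1800])                     # posix time -> time coordinate 0.5

from_latlon() delegates to the Newton-based bilin_inv() helper, so the lat/lon round trip is only exact up to the tol parameter of that function.
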
ladim/gridforce/ROMS.py CHANGED
@@ -251,10 +251,11 @@ class Forcing:

  """

- def __init__(self, config, grid):
+ def __init__(self, config, _):

  logger.info("Initiating forcing")

+ grid = Grid(config)
  self._grid = grid # Get the grid object, make private?
  # self.config = config["gridforce"]
  self.ibm_forcing = config["ibm_forcing"]
@@ -292,7 +293,12 @@ class Forcing:
  # --------------
  # prestep = last forcing step < 0
  #
+ self.has_been_initialized = False
+ self.steps = steps
+ self._files = files

+ def _remaining_initialization(self):
+ steps = self.steps
  V = [step for step in steps if step < 0]
  if V: # Forcing available before start time
  prestep = max(V)
@@ -335,9 +341,7 @@ class Forcing:
  else:
  # No forcing at start, should already be excluded
  raise SystemExit(3)
-
- self.steps = steps
- self._files = files
+ self.has_been_initialized = True

  # ===================================================
  @staticmethod
@@ -435,6 +439,9 @@ class Forcing:
  def update(self, t):
  """Update the fields to time step t"""

+ if not self.has_been_initialized:
+ self._remaining_initialization()
+
  # Read from config?
  interpolate_velocity_in_time = True
  interpolate_ibm_forcing_in_time = False
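
The changes above split the legacy ROMS Forcing setup in two: __init__ now only records steps and files, and the remaining file-dependent work is postponed to the first update() call, guarded by has_been_initialized. The following is a generic, self-contained sketch of this deferred-initialization pattern, not the actual ladim class:

class LazyForcing:
    def __init__(self, steps):
        # cheap bookkeeping only; nothing expensive happens here
        self.steps = steps
        self.has_been_initialized = False

    def _remaining_initialization(self):
        # the deferred, expensive part, e.g. locating the last forcing step before start
        self.prestep = max((s for s in self.steps if s < 0), default=None)
        self.has_been_initialized = True

    def update(self, t):
        if not self.has_been_initialized:
            self._remaining_initialization()
        return self.prestep


forcing = LazyForcing(steps=[-3, -1, 2, 5])
assert not forcing.has_been_initialized
forcing.update(0)                      # the first call triggers the deferred setup
assert forcing.has_been_initialized and forcing.prestep == -1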