ladim 2.0.9__py3-none-any.whl → 2.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ladim/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- __version__ = '2.0.9'
1
+ __version__ = '2.1.6'
2
2
 
3
3
  from .main import main, run
ladim/config.py CHANGED
@@ -82,7 +82,6 @@ def convert_1_to_2(c):
82
82
  out['solver']['stop'] = dict_get(c, 'time_control.stop_time')
83
83
  out['solver']['step'] = dt_sec
84
84
  out['solver']['seed'] = dict_get(c, 'numerics.seed')
85
- out['solver']['order'] = ['release', 'forcing', 'output', 'tracker', 'ibm', 'state']
86
85
 
87
86
  out['grid'] = {}
88
87
  out['grid']['file'] = dict_get(c, [
@@ -93,7 +92,7 @@ def convert_1_to_2(c):
93
92
  out['grid']['start_time'] = np.datetime64(dict_get(c, 'time_control.start_time', '1970'), 's')
94
93
  out['grid']['subgrid'] = dict_get(c, 'gridforce.subgrid', None)
95
94
 
96
- out['forcing'] = {}
95
+ out['forcing'] = {k: v for k, v in c.get('gridforce', {}).items() if k not in ('input_file', 'module')}
97
96
  out['forcing']['file'] = dict_get(c, ['gridforce.input_file', 'files.input_file'])
98
97
  out['forcing']['first_file'] = dict_get(c, 'gridforce.first_file', "")
99
98
  out['forcing']['last_file'] = dict_get(c, 'gridforce.last_file', "")
@@ -142,7 +141,6 @@ def convert_1_to_2(c):
142
141
 
143
142
  out['ibm'] = {}
144
143
  if 'ibm' in c:
145
- out['ibm']['module'] = 'ladim.ibms.LegacyIBM'
146
144
  out['ibm']['legacy_module'] = dict_get(c, ['ibm.ibm_module', 'ibm.module'])
147
145
  if out['ibm']['legacy_module'] == 'ladim.ibms.ibm_salmon_lice':
148
146
  out['ibm']['legacy_module'] = 'ladim_plugins.salmon_lice'
ladim/forcing.py CHANGED
@@ -1,10 +1,23 @@
1
- from .model import Model, Module
1
+ import typing
2
+ if typing.TYPE_CHECKING:
3
+ from ladim.model import Model
4
+ import numexpr
5
+ import string
6
+ import numpy as np
7
+ from numba import njit
2
8
 
3
9
 
4
- class Forcing(Module):
10
+ class Forcing:
11
+ @staticmethod
12
+ def from_roms(**conf):
13
+ return RomsForcing(**conf)
14
+
5
15
  def velocity(self, X, Y, Z, tstep=0.0):
6
16
  raise NotImplementedError
7
17
 
18
+ def update(self, model: "Model"):
19
+ raise NotImplementedError
20
+
8
21
 
9
22
  class RomsForcing(Forcing):
10
23
  def __init__(self, file, variables=None, **conf):
@@ -37,11 +50,7 @@ class RomsForcing(Forcing):
37
50
 
38
51
  grid_ref = GridReference()
39
52
  legacy_conf = dict(
40
- gridforce=dict(
41
- input_file=file,
42
- first_file=conf.get('first_file', ""),
43
- last_file=conf.get('last_file', ""),
44
- ),
53
+ gridforce=dict(input_file=file, **conf),
45
54
  ibm_forcing=conf.get('ibm_forcing', []),
46
55
  start_time=conf.get('start_time', None),
47
56
  stop_time=conf.get('stop_time', None),
@@ -50,7 +59,7 @@ class RomsForcing(Forcing):
50
59
  if conf.get('subgrid', None) is not None:
51
60
  legacy_conf['gridforce']['subgrid'] = conf['subgrid']
52
61
 
53
- from .model import load_class
62
+ from .utilities import load_class
54
63
  LegacyForcing = load_class(conf.get('legacy_module', 'ladim.gridforce.ROMS.Forcing'))
55
64
 
56
65
  # Allow gridforce module in current directory
@@ -63,7 +72,7 @@ class RomsForcing(Forcing):
63
72
  # self.U = self.forcing.U
64
73
  # self.V = self.forcing.V
65
74
 
66
- def update(self, model: Model):
75
+ def update(self, model: "Model"):
67
76
  elapsed = model.solver.time - model.solver.start
68
77
  t = elapsed // model.solver.step
69
78
 
@@ -93,3 +102,447 @@ class GridReference:
93
102
 
94
103
  def __getattr__(self, item):
95
104
  return getattr(self.modules.grid.grid, item)
105
+
106
+
107
def load_netcdf_chunk(url, varname, subset):
    """
    Read a subset of one variable from a netCDF dataset (file path or url)

    :param url: Dataset location, passed unchanged to ``xarray.open_dataset``
    :param varname: Name of the variable to read
    :param subset: Index expression applied to the variable before loading
    :returns: The selected values as a numpy array. For the variables
        'u', 'v' and 'w' the result is passed through ``np.nan_to_num``.
    """
    import xarray as xr

    with xr.open_dataset(url) as dataset:
        chunk = dataset.variables[varname][subset].values

    # Velocity components must be finite everywhere: NaN entries become zero
    # (presumably these are masked cells - TODO confirm against the data).
    if varname in ['u', 'v', 'w']:
        return np.nan_to_num(chunk)
    return chunk
117
+
118
+
119
class ChunkCache:
    """
    A cache for storing and sharing chunks of data using shared memory.

    This class manages a memory block divided into a header, an LRU table,
    an index, and a data section. It is designed for efficient inter-process
    communication of chunked data arrays. The class itself performs no
    locking; concurrent writers must be coordinated by the caller.

    Memory layout (``n`` is the number of slots):

    - header, 32 bytes: num_chunks (int64), chunksize (int64),
      datatype (8 ascii bytes), itemsize (int64)
    - LRU table, ``2 * n`` bytes: slot numbers as int16, most recently
      used slot first
    - index, ``8 * n`` bytes: chunk id stored in each slot as int64,
      where -1 marks a free slot
    - data, ``n * chunksize * itemsize`` bytes: one row per slot

    :ivar mem: SharedMemory object representing the memory block.
    :ivar num_chunks: Number of slots/chunks in the cache (read-only).
    :ivar chunksize: Size of each chunk (read-only).
    :ivar datatype: Data type of the stored chunks (read-only).
    :ivar itemsize: Size in bytes of each data item (read-only).
    :ivar lru: Array of slot numbers ordered from most to least recently used.
    :ivar chunk_id: Array of chunk IDs for tracking which data is stored in each slot.
    :ivar data: 2D array holding the actual chunked data.
    """
    def __init__(self, name: str):
        """
        Attach to an existing shared memory block and map the cache structure.

        :param name: The name of the shared memory block to attach to.
        """
        from multiprocessing.shared_memory import SharedMemory
        mem = SharedMemory(name=name, create=False)
        self.mem = mem

        # Header block
        self.num_chunks = np.ndarray(shape=(), dtype=np.int64, buffer=mem.buf[0:8])
        self.chunksize = np.ndarray(shape=(), dtype=np.int64, buffer=mem.buf[8:16])
        self.datatype = np.ndarray(shape=(), dtype='S8', buffer=mem.buf[16:24])
        self.itemsize = np.ndarray(shape=(), dtype=np.int64, buffer=mem.buf[24:32])
        self.num_chunks.setflags(write=False)
        self.chunksize.setflags(write=False)
        self.datatype.setflags(write=False)
        self.itemsize.setflags(write=False)

        # Plain integers for the offset arithmetic below
        num_chunks = int(self.num_chunks)
        chunksize = int(self.chunksize)
        itemsize = int(self.itemsize)

        # LRU block
        lru_start = 32
        lru_stop = lru_start + 2 * num_chunks
        self.lru = np.ndarray(
            shape=(num_chunks,),
            dtype=np.int16,
            buffer=mem.buf[lru_start:lru_stop])

        # Index block
        idx_start = lru_stop
        idx_stop = idx_start + 8 * num_chunks
        self.chunk_id = np.ndarray(
            shape=(num_chunks, ),
            dtype=np.int64,
            buffer=mem.buf[idx_start:idx_stop])

        # Data block
        dat_start = idx_stop
        dat_stop = dat_start + num_chunks * chunksize * itemsize
        self.data = np.ndarray(
            shape=(num_chunks, chunksize),
            dtype=self.datatype.item().decode('ascii'),
            buffer=mem.buf[dat_start:dat_stop])

    def _update_lru(self, slot: int) -> None:
        """
        Move the given slot to the front (most recently used) in the LRU table.
        """
        update_lru(self.lru, slot)

    def read(self, slot: int) -> np.ndarray:
        """
        Read data from the given slot and update the LRU table.

        :param slot: The slot index to read
        :return: The data in the slot (a view into the shared memory block,
            not a copy)
        """
        self._update_lru(slot)
        return self.data[slot, :]

    def write(self, data: np.ndarray, slot: int) -> None:
        """
        Overwrite the data in the given slot and update the LRU table.

        :param data: 1D numpy array of length self.chunksize and dtype self.datatype
        :param slot: The slot index to overwrite
        """
        self._update_lru(slot)
        self.data[slot, :] = data

    def __enter__(self) -> "ChunkCache":
        """
        Enter the runtime context related to this object.
        Returns self for use in 'with' statements.

        :return: self
        """
        return self

    def __exit__(self, type: type, value: Exception, tb: object) -> None:
        """
        Exit the runtime context and close the shared memory.

        :param type: Exception type
        :param value: Exception value
        :param tb: Traceback object
        """
        self.close()

    def __setattr__(self, name: str, value: object) -> None:
        """
        Prevent reassignment of attributes after initialization.
        Raises AttributeError if an attribute is already set.

        :param name: Attribute name
        :param value: Attribute value
        :raises AttributeError: If attribute is already set
        """
        if hasattr(self, name):
            raise AttributeError(f"Cannot reassign attribute '{name}'")
        super().__setattr__(name, value)

    @staticmethod
    def create(slots: int, chunksize: int, datatype: str = 'f4') -> "ChunkCache":
        """
        Create a new shared memory block and initialize a ChunkCache.

        The block is created but never unlinked by this class; the caller
        owns it and is responsible for ``cache.mem.unlink()`` when done.

        :param slots: Number of slots/chunks in the cache.
        :param chunksize: Size of each chunk.
        :param datatype: Numpy dtype string for the data (default 'f4').
        :return: An instance attached to the new shared memory block.
        :raises ValueError: If the dtype name does not fit in the 8-byte
            header field, or if there are more slots than the int16 LRU
            table can address.
        """
        from multiprocessing.shared_memory import SharedMemory

        test_item = np.empty((), dtype=datatype)
        str_dtype = str(test_item.dtype)
        if len(str_dtype) > 8:
            raise ValueError(f'Unsupported data type: {str_dtype}')

        # Slot numbers are stored as int16 in the LRU table
        max_slots = int(np.iinfo(np.int16).max) + 1
        if slots > max_slots:
            raise ValueError(f'Too many slots: {slots} (maximum is {max_slots})')

        # Reserve memory space for the cache block
        size_header_block = 32
        size_lru_block = 2 * slots  # int16
        size_index_block = 8 * slots
        size_data_block = slots * chunksize * test_item.itemsize
        total_size = size_header_block + size_lru_block + size_index_block + size_data_block
        mem = SharedMemory(create=True, size=total_size)

        try:
            # The numpy views below are temporaries on purpose: the creator
            # handle can only be closed once no view into its buffer remains.

            # Write some header information
            np.ndarray(shape=(), dtype=np.int64, buffer=mem.buf[0:8])[...] = slots
            np.ndarray(shape=(), dtype=np.int64, buffer=mem.buf[8:16])[...] = chunksize
            np.ndarray(shape=(), dtype='S8', buffer=mem.buf[16:24])[...] = str_dtype
            np.ndarray(shape=(), dtype=np.int64, buffer=mem.buf[24:32])[...] = test_item.itemsize

            # LRU block: initial order is simply slot 0, 1, 2, ...
            lru_start = size_header_block
            index_start = lru_start + size_lru_block
            np.ndarray(
                shape=(slots,),
                dtype=np.int16,
                buffer=mem.buf[lru_start:index_start],
            )[...] = np.arange(slots, dtype=np.int16)

            # Index block: all slots start out free
            np.ndarray(
                shape=(slots, ),
                dtype=np.int64,
                buffer=mem.buf[index_start:index_start + size_index_block],
            )[...] = -1

            # Data block
            # (no need to initialize, will be written on use)

            # Attach the returned instance before the creator handle is closed
            return ChunkCache(mem.name)
        finally:
            mem.close()

    def close(self) -> None:
        """
        Close the shared memory block.

        The numpy views held by this instance are dropped first, because the
        underlying buffer cannot be closed while views into it are alive.
        After closing, the array attributes are no longer available. Arrays
        obtained earlier from ``read`` or ``pull`` must be released by the
        caller before calling this method.
        """
        views = ('num_chunks', 'chunksize', 'datatype', 'itemsize',
                 'lru', 'chunk_id', 'data')
        for attr in views:
            # Bypasses __setattr__, which only guards against reassignment
            self.__dict__.pop(attr, None)
        self.mem.close()

    def contains(self, id: int) -> bool:
        """
        Check if the cache contains a chunk with the given id.

        :param id: The chunk id to check
        :return: True if the chunk is in the cache, False otherwise
        """
        return indexof(self.chunk_id, id) >= 0

    def push(self, data: np.ndarray, id: int) -> None:
        """
        Push a chunk of data into the cache with the given id.

        :param data: 1D numpy array of length self.chunksize and dtype self.datatype
        :param id: The chunk id to associate with this data
        :note: If the id is already cached, its slot is overwritten. Otherwise
            the first free slot is used, and if no free slots are available,
            the least recently used slot is evicted.
        """
        # Reuse the slot of an already cached id, so that no stale duplicate
        # is left behind for indexof() to find first
        slot = indexof(self.chunk_id, id)
        if slot < 0:
            # First free slot, if any
            slot = indexof(self.chunk_id, -1)
        if slot < 0:
            # Evict the least recently used slot (last in lru)
            slot = int(self.lru[-1])
        self.write(data, slot)
        self.chunk_id[slot] = id

    def pull(self, id: int) -> np.ndarray:
        """
        Retrieve the data for the given chunk id and update the LRU table.

        :param id: The chunk id to retrieve
        :return: The data array for the chunk
        :raises KeyError: If the chunk id is not found in the cache
        """
        slot = indexof(self.chunk_id, id)
        if slot < 0:
            raise KeyError(f"Chunk id {id} not found in cache")
        return self.read(slot)
336
+
337
+
338
def timestring_formatter(pattern, time):
    """
    Format a time string

    Every replacement field ``{expression:spec}`` in the pattern is handled
    in two steps. The expression is evaluated by numexpr with the variable
    ``time`` bound to the posix time in seconds. The result is then read as
    a posix time in seconds and formatted with ``strftime`` using ``spec``.
    For instance, the pattern ``"{time:%Y%m%d}.nc"`` with time
    ``"2020-01-02"`` is expected to give ``"20200102.nc"``.

    :param pattern: f-string style formatting pattern
    :param time: Numpy convertible time specification
    :returns: A formatted time string
    """
    # Explicit 64-bit integer: the width of plain ``int`` depends on the
    # platform and numpy version
    posix_time = np.datetime64(time, 's').astype(np.int64)

    class PosixFormatter(string.Formatter):
        # typing.Union instead of ``int | str``: annotations are evaluated
        # when the method is defined, and ``|`` on types needs Python 3.10
        def get_value(
                self,
                key: typing.Union[int, str],
                args: typing.Sequence[typing.Any],
                kwargs: typing.Mapping[str, typing.Any],
        ) -> typing.Any:
            # The field name is an expression; only the keyword arguments
            # given to format() are visible to it
            return numexpr.evaluate(
                key, local_dict=kwargs, global_dict=dict())

        def format_field(self, value: typing.Any, format_spec: str) -> typing.Any:
            # Posix seconds -> datetime.datetime -> strftime
            dt = np.int64(value).astype('datetime64[s]').astype(object)
            return dt.strftime(format_spec)

    fmt = PosixFormatter()
    return fmt.format(pattern, time=posix_time)
359
+
360
+
361
@njit
def update_lru(lru: np.ndarray, slot: int) -> None:
    """
    Move a slot to the front of the LRU (Least Recently Used) list, in place.

    Entries ahead of the slot's old position are shifted one step back. If
    the slot is not present, every entry is shifted and the last one drops
    out of the list.

    :param lru: The LRU array, most recently used slot first
    :param slot: The slot index to move to the front
    """
    carried = slot
    for pos in range(len(lru)):
        displaced = lru[pos]
        lru[pos] = carried
        if displaced == slot:
            # Reached the old position of the slot: the shift is complete
            return
        carried = displaced
376
+
377
+
378
@njit
def indexof(array: np.ndarray, value: int) -> int:
    """
    Linear search for a value in an array.

    :param array: The input array
    :param value: The value to find
    :return: The index of the first occurrence, or -1 if not found
    """
    for position, element in enumerate(array):
        if element == value:
            return position
    return -1
391
+
392
+
393
@njit(inline="always", fastmath=True)
def get_chunk_id(i, j, l, size_x, size_y, num_x, num_y):
    """
    Calculate the chunk ID based on the indices and sizes.

    The array is assumed to be chunked in the x and y dimensions but not
    in the z dimension, and the t dimension is assumed to be chunked with
    size_t = 1. Chunks are numbered with x varying fastest, then y, then t.

    Example: with chunk size x:10, y:5 and 6 and 7 chunks in the x and y
    dimensions respectively, the coordinates (31, 14, 0) give

        chunk index in x direction: 31 // 10 = 3
        chunk index in y direction: 14 // 5 = 2
        chunk index in t direction: 0 // 1 = 0
        chunk id = 3 + 6*2 + 6*7*0 = 15

    :param i: Integer x coordinate (global index)
    :param j: Integer y coordinate (global index)
    :param l: Integer t coordinate (global index)
    :param size_x: Chunk size in x dimension
    :param size_y: Chunk size in y dimension
    :param num_x: Number of chunks in x dimension
    :param num_y: Number of chunks in y dimension
    :return: Chunk ID
    """
    # Per-dimension chunk index; the t index is used as it is (chunk size 1)
    chunk_x = i // size_x
    chunk_y = j // size_y

    # Flatten (t, y, x) chunk indices into a single integer
    return chunk_x + num_x * (chunk_y + num_y * l)
428
+
429
+
430
@njit(inline="always", fastmath=True)
def interp_xyzt(x, y, z, t, v, ncx, ncy, ncz, nct, cache, ids):
    """
    Interpolate the data linearly in the x, y, z, and t dimensions.

    The value is a weighted sum of the 16 grid points surrounding the
    position. The lower corner index is clamped to the valid range in each
    dimension, but the weights are not, so positions outside the grid are
    extrapolated from the outermost cell.

    The chunk lookup relies on ``get_chunk_id``, which assumes that the
    data is not chunked in z and that the chunk size in t is 1.

    :param x: x coordinate (global index)
    :param y: y coordinate (global index)
    :param z: z coordinate (global index)
    :param t: t coordinate (global index)
    :param v: Index of the variable (second axis of the cache array)
    :param ncx: Number of chunks in x dimension
    :param ncy: Number of chunks in y dimension
    :param ncz: Number of chunks in z dimension
    :param nct: Number of chunks in t dimension
    :param cache: Chunk cache array, indexed as [slot, variable, t, z, y, x]
    :param ids: Array of chunk ids for each slot in the cache
    :return: Interpolated value, or nan if a required chunk is not cached
    """
    _, _, st, sz, sy, sx = cache.shape

    max_x = ncx * sx
    max_y = ncy * sy
    max_z = ncz * sz
    max_t = nct * st

    # Lower corner of the surrounding cell, clamped so that the upper
    # corner is also a valid index
    i0 = max(0, min(max_x - 2, np.int32(x)))
    j0 = max(0, min(max_y - 2, np.int32(y)))
    k0 = max(0, min(max_z - 2, np.int32(z)))
    l0 = max(0, min(max_t - 2, np.int32(t)))

    i1 = i0 + 1
    j1 = j0 + 1
    k1 = k0 + 1
    l1 = l0 + 1

    # Chunk ID for the surrounding points. The three digits are (t, y, x);
    # z does not enter since the data is not chunked in z.
    chid000 = get_chunk_id(i0, j0, l0, sx, sy, ncx, ncy)
    chid001 = get_chunk_id(i1, j0, l0, sx, sy, ncx, ncy)
    chid010 = get_chunk_id(i0, j1, l0, sx, sy, ncx, ncy)
    chid011 = get_chunk_id(i1, j1, l0, sx, sy, ncx, ncy)
    chid100 = get_chunk_id(i0, j0, l1, sx, sy, ncx, ncy)
    chid101 = get_chunk_id(i1, j0, l1, sx, sy, ncx, ncy)
    chid110 = get_chunk_id(i0, j1, l1, sx, sy, ncx, ncy)
    chid111 = get_chunk_id(i1, j1, l1, sx, sy, ncx, ncy)

    # Cache slot holding each chunk, same (t, y, x) digits
    slot000 = indexof(ids, chid000)
    slot001 = indexof(ids, chid001)
    slot010 = indexof(ids, chid010)
    slot011 = indexof(ids, chid011)
    slot100 = indexof(ids, chid100)
    slot101 = indexof(ids, chid101)
    slot110 = indexof(ids, chid110)
    slot111 = indexof(ids, chid111)

    # Return nan if any of the slots are not found
    if (slot000 < 0 or slot001 < 0 or slot010 < 0 or slot011 < 0 or
            slot100 < 0 or slot101 < 0 or slot110 < 0 or slot111 < 0):
        return np.nan

    # Within-chunk indices
    ii0 = i0 % sx
    ii1 = i1 % sx
    jj0 = j0 % sy
    jj1 = j1 % sy
    kk0 = k0 % sz
    kk1 = k1 % sz
    ll0 = l0 % st
    ll1 = l1 % st

    # Memory extraction. The four digits of each value are (t, z, y, x),
    # matching the weights below: u_tzyx is read from the slot of chunk
    # (t, y, x) at within-chunk position (ll_t, kk_z, jj_y, ii_x).
    u0000 = cache[slot000, v, ll0, kk0, jj0, ii0]
    u0001 = cache[slot001, v, ll0, kk0, jj0, ii1]
    u0010 = cache[slot010, v, ll0, kk0, jj1, ii0]
    u0011 = cache[slot011, v, ll0, kk0, jj1, ii1]
    u0100 = cache[slot000, v, ll0, kk1, jj0, ii0]
    u0101 = cache[slot001, v, ll0, kk1, jj0, ii1]
    u0110 = cache[slot010, v, ll0, kk1, jj1, ii0]
    u0111 = cache[slot011, v, ll0, kk1, jj1, ii1]
    u1000 = cache[slot100, v, ll1, kk0, jj0, ii0]
    u1001 = cache[slot101, v, ll1, kk0, jj0, ii1]
    u1010 = cache[slot110, v, ll1, kk0, jj1, ii0]
    u1011 = cache[slot111, v, ll1, kk0, jj1, ii1]
    u1100 = cache[slot100, v, ll1, kk1, jj0, ii0]
    u1101 = cache[slot101, v, ll1, kk1, jj0, ii1]
    u1110 = cache[slot110, v, ll1, kk1, jj1, ii0]
    u1111 = cache[slot111, v, ll1, kk1, jj1, ii1]

    # Interpolation weights
    # The weights are calculated as the distance from the lower bound
    p = x - i0
    q = y - j0
    r = z - k0
    s = t - l0

    w0000 = (1 - s) * (1 - r) * (1 - q) * (1 - p)
    w0001 = (1 - s) * (1 - r) * (1 - q) * p
    w0010 = (1 - s) * (1 - r) * q * (1 - p)
    w0011 = (1 - s) * (1 - r) * q * p
    w0100 = (1 - s) * r * (1 - q) * (1 - p)
    w0101 = (1 - s) * r * (1 - q) * p
    w0110 = (1 - s) * r * q * (1 - p)
    w0111 = (1 - s) * r * q * p
    w1000 = s * (1 - r) * (1 - q) * (1 - p)
    w1001 = s * (1 - r) * (1 - q) * p
    w1010 = s * (1 - r) * q * (1 - p)
    w1011 = s * (1 - r) * q * p
    w1100 = s * r * (1 - q) * (1 - p)
    w1101 = s * r * (1 - q) * p
    w1110 = s * r * q * (1 - p)
    w1111 = s * r * q * p

    # Interpolation
    result = (w0000 * u0000 + w0001 * u0001 + w0010 * u0010 + w0011 * u0011 +
              w0100 * u0100 + w0101 * u0101 + w0110 * u0110 + w0111 * u0111 +
              w1000 * u1000 + w1001 * u1001 + w1010 * u1010 + w1011 * u1011 +
              w1100 * u1100 + w1101 * u1101 + w1110 * u1110 + w1111 * u1111)

    return result
ladim/grid.py CHANGED
@@ -1,16 +1,19 @@
1
- from .model import Module
2
1
  import numpy as np
3
2
  from typing import Sequence
4
3
  from scipy.ndimage import map_coordinates
5
4
 
6
5
 
7
- class Grid(Module):
6
+ class Grid:
8
7
  """
9
8
  The grid class represents the coordinate system used for particle tracking.
10
9
  It contains methods for converting between global coordinates (latitude,
11
10
  longitude, depth and posix time) and internal coordinates.
12
11
  """
13
12
 
13
+ @staticmethod
14
+ def from_roms(**conf):
15
+ return RomsGrid(**conf)
16
+
14
17
  def ingrid(self, X, Y):
15
18
  raise NotImplementedError
16
19
 
@@ -188,7 +191,7 @@ class RomsGrid(Grid):
188
191
  if subgrid is not None:
189
192
  legacy_conf['gridforce']['subgrid'] = subgrid
190
193
 
191
- from .model import load_class
194
+ from .utilities import load_class
192
195
  LegacyGrid = load_class(legacy_module)
193
196
 
194
197
  # Allow gridforce module in current directory
ladim/gridforce/ROMS.py CHANGED
@@ -62,7 +62,7 @@ class Grid:
62
62
  # Here, imax, jmax refers to whole grid
63
63
  jmax, imax = ncid.variables["h"].shape
64
64
  whole_grid = [1, imax - 1, 1, jmax - 1]
65
- if "subgrid" in config["gridforce"]:
65
+ if config["gridforce"].get('subgrid', None):
66
66
  limits = list(config["gridforce"]["subgrid"])
67
67
  else:
68
68
  limits = whole_grid
ladim/ibms/__init__.py CHANGED
@@ -1,18 +1,22 @@
1
- from ..model import Model, Module
2
1
  import numpy as np
2
+ import typing
3
3
 
4
+ if typing.TYPE_CHECKING:
5
+ from ..model import Model
4
6
 
5
- class IBM(Module):
6
- pass
7
7
 
8
+ class IBM:
9
+ def __init__(self, legacy_module=None, conf: dict = None):
10
+ from ..utilities import load_class
8
11
 
9
- class LegacyIBM(IBM):
10
- def __init__(self, legacy_module, conf):
11
- from ..model import load_class
12
- LegacyIbmClass = load_class(legacy_module + '.IBM')
13
- self._ibm = LegacyIbmClass(conf)
12
+ if legacy_module is None:
13
+ UserIbmClass = EmptyIBM
14
+ else:
15
+ UserIbmClass = load_class(legacy_module + '.IBM')
14
16
 
15
- def update(self, model: Model):
17
+ self.user_ibm = UserIbmClass(conf or {})
18
+
19
+ def update(self, model: "Model"):
16
20
  grid = model.grid
17
21
  state = model.state
18
22
 
@@ -23,4 +27,12 @@ class LegacyIBM(IBM):
23
27
  )
24
28
 
25
29
  forcing = model.forcing
26
- self._ibm.update_ibm(grid, state, forcing)
30
+ self.user_ibm.update_ibm(grid, state, forcing)
31
+
32
+
33
class EmptyIBM:
    """
    Placeholder IBM that does nothing.

    ``IBM`` selects this class when no legacy module is given.
    """

    def __init__(self, _):
        """Accept the configuration argument and ignore it."""

    def update_ibm(self, grid, state, forcing):
        """Leave grid, state and forcing untouched."""
        return None