ladim 2.0.9__py3-none-any.whl → 2.1.6__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- ladim/__init__.py +1 -1
- ladim/config.py +1 -3
- ladim/forcing.py +462 -9
- ladim/grid.py +6 -3
- ladim/gridforce/ROMS.py +1 -1
- ladim/ibms/__init__.py +22 -10
- ladim/model.py +21 -86
- ladim/output.py +5 -7
- ladim/release.py +268 -130
- ladim/solver.py +12 -8
- ladim/state.py +20 -46
- ladim/tracker.py +13 -11
- ladim/utilities.py +28 -0
- {ladim-2.0.9.dist-info → ladim-2.1.6.dist-info}/METADATA +5 -2
- ladim-2.1.6.dist-info/RECORD +32 -0
- {ladim-2.0.9.dist-info → ladim-2.1.6.dist-info}/WHEEL +1 -1
- ladim-2.0.9.dist-info/RECORD +0 -32
- {ladim-2.0.9.dist-info → ladim-2.1.6.dist-info}/entry_points.txt +0 -0
- {ladim-2.0.9.dist-info → ladim-2.1.6.dist-info/licenses}/LICENSE +0 -0
- {ladim-2.0.9.dist-info → ladim-2.1.6.dist-info}/top_level.txt +0 -0
ladim/model.py
CHANGED
```diff
@@ -1,29 +1,11 @@
-import
-import
-import
-from
-
-from
-
-
-from ladim.forcing import Forcing
-from ladim.ibms import IBM
-from ladim.output import Output
-from ladim.release import Releaser
-from ladim.state import State
-from ladim.tracker import Tracker
-from ladim.solver import Solver
-
-DEFAULT_MODULES = dict(
-    grid='ladim.grid.RomsGrid',
-    forcing='ladim.forcing.RomsForcing',
-    release='ladim.release.TextFileReleaser',
-    state='ladim.state.DynamicState',
-    output='ladim.output.RaggedOutput',
-    ibm='ladim.ibms.IBM',
-    tracker='ladim.tracker.HorizontalTracker',
-    solver='ladim.solver.Solver',
-)
+from ladim.ibms import IBM
+from ladim.solver import Solver
+from ladim.release import Releaser
+from ladim.grid import Grid
+from ladim.forcing import Forcing
+from ladim.state import State
+from ladim.tracker import Tracker
+from ladim.output import Output
 
 
 class Model:
@@ -58,25 +40,22 @@ class Model:
         :return: An initialized Model class
         """
 
-
-
-        return {k: v for k, v in d.items() if k != 'module'}
+        grid = Grid.from_roms(**config['grid'])
+        forcing = Forcing.from_roms(**config['forcing'])
 
-
-
-            'grid', 'forcing', 'release', 'state', 'output', 'ibm', 'tracker',
-            'solver',
+        release = Releaser.from_textfile(
+            lonlat_converter=grid.ll2xy, **config['release']
         )
-
-
-
-
-
-            module=subconf.get('module', DEFAULT_MODULES[name]),
-        )
+        tracker = Tracker.from_config(**config['tracker'])
+
+        output = Output(**config['output'])
+        ibm = IBM(**config['ibm'])
+        solver = Solver(**config['solver'])
 
-
-
+        state = State()
+
+        # noinspection PyTypeChecker
+        return Model(grid, forcing, release, state, output, ibm, tracker, solver)
 
     @property
     def modules(self) -> dict:
@@ -98,47 +77,3 @@ class Model:
         for m in self.modules.values():
             if hasattr(m, 'close') and callable(m.close):
                 m.close()
-
-
-def load_class(name):
-    pkg, cls = name.rsplit(sep='.', maxsplit=1)
-
-    # Check if "pkg" is an existing file
-    spec = None
-    module_name = None
-    file_name = pkg + '.py'
-    if Path(file_name).exists():
-        # This can return None if there were import errors
-        module_name = pkg
-        spec = importlib.util.spec_from_file_location(module_name, file_name)
-
-    # If pkg can not be interpreted as a file, use regular import
-    if spec is None:
-        return getattr(importlib.import_module(pkg), cls)
-
-    # File import
-    else:
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[module_name] = module
-        spec.loader.exec_module(module)
-        return getattr(module, cls)
-
-
-class Module:
-    @staticmethod
-    def from_config(conf: dict, module: str) -> "Module":
-        """
-        Initialize a module using a configuration dict.
-
-        :param conf: The configuration parameters of the module
-        :param module: The fully qualified name of the module
-        :return: An initialized module
-        """
-        cls = load_class(module)
-        return cls(**conf)
-
-    def update(self, model: Model):
-        pass
-
-    def close(self):
-        pass
```
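The net effect of this change: 2.0.x resolved its component classes from dotted-path strings via `DEFAULT_MODULES` and the removed `load_class` helper, while 2.1.x imports every component directly and calls a named factory per config section. A minimal, self-contained sketch of the two resolution styles, using `collections.OrderedDict` as a stand-in target rather than ladim's own classes:

```python
import importlib

# 2.0.x style: resolve a class from a dotted-path string at runtime.
# Same rsplit as the removed load_class helper; the file-based branch is omitted.
def load_class(name):
    pkg, cls = name.rsplit(sep='.', maxsplit=1)
    return getattr(importlib.import_module(pkg), cls)

DynamicOrderedDict = load_class('collections.OrderedDict')

# 2.1.x style: resolve the class with an ordinary import instead.
from collections import OrderedDict

assert DynamicOrderedDict is OrderedDict
```

The explicit style trades the old plug-in flexibility (including the removed ability to load classes from arbitrary `.py` files) for import-time failures and easier static analysis.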
ladim/output.py
CHANGED
```diff
@@ -1,13 +1,11 @@
-from .model import Model, Module
 import netCDF4 as nc
 import numpy as np
+import typing
+if typing.TYPE_CHECKING:
+    from .model import Model
 
 
-class Output(Module):
-    pass
-
-
-class RaggedOutput(Output):
+class Output:
     def __init__(self, variables: dict, file: str, frequency):
         """
         Writes simulation output to netCDF file in ragged array format
@@ -52,7 +50,7 @@ class RaggedOutput(Output):
         """Returns a handle to the netCDF dataset currently being written to"""
         return self._dset
 
-    def update(self, model: Model):
+    def update(self, model: "Model"):
         if self._dset is None:
             self._create_dset()
 
```
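The import change follows the standard `typing.TYPE_CHECKING` idiom: `Model` is only imported during static analysis and the annotation becomes a string, which removes the runtime import cycle between `output.py` and `model.py` without losing type information. A generic illustration, with `mypackage.model` as a hypothetical module name:

```python
import typing

# Only evaluated by static type checkers; skipped at runtime, so importing
# this module never triggers the (hypothetical) mypackage.model import.
if typing.TYPE_CHECKING:
    from mypackage.model import Model

class Output:
    def update(self, model: "Model") -> None:
        # The quoted annotation is never resolved at runtime, so the name
        # Model does not need to exist here when this method is called.
        ...
```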
ladim/release.py
CHANGED
```diff
@@ -1,173 +1,110 @@
 import contextlib
-
-from .model import Model, Module
 import numpy as np
 import pandas as pd
 from .utilities import read_timedelta
 import logging
+import typing
 
+if typing.TYPE_CHECKING:
+    from ladim.model import Model
 
-logger = logging.getLogger(__name__)
 
+logger = logging.getLogger(__name__)
 
-class Releaser(Module):
-    pass
 
+class Releaser:
+    def __init__(self, particle_generator: typing.Callable[[float, float], pd.DataFrame]):
+        self.particle_generator = particle_generator
 
-
-    def
-
-            frequency=(0, 's'), defaults=None,
+    @staticmethod
+    def from_textfile(
+            file, colnames: list = None, formats: dict = None,
+            frequency=(0, 's'), defaults=None, lonlat_converter=None,
     ):
         """
         Release module which reads from a text file
 
         The text file must be a whitespace-separated csv file
 
+        :param lonlat_converter: Function that converts lon, lat coordinates to
+            x, y coordinates
+
         :param file: Release file
 
         :param colnames: Column names, if the release file does not contain any
 
         :param formats: Data column formats, one dict entry per column. If any column
-
-
-
-
+            is missing, the default format is used. Keys should correspond to column names.
+            Values should be either ``"float"``, ``"int"`` or ``"time"``. Default value
+            is ``"float"`` for all columns except ``release_time``, which has default
+            value ``"time"``.
 
         :param frequency: A two-element list with entries ``[value, unit]``, where
-
+            ``unit`` can be any numpy-compatible timedelta unit (such as "s", "m", "h", "D").
 
         :param defaults: A dict of variables to be added to each particle. The keys
             are the variable names, the values are the initial values at particle
             release.
         """
 
-
-
-
-
-
-
-
-
-
-
-
-        # Other parameters
-        self._defaults = defaults or dict()
-
-    def update(self, model: Model):
+        release_table = ReleaseTable.from_filename_or_stream(
+            file=file,
+            column_names=colnames,
+            column_formats=formats or dict(),
+            interval=read_timedelta(frequency) / np.timedelta64(1, 's'),
+            defaults=defaults or dict(),
+            lonlat_converter=lonlat_converter,
+        )
+        return Releaser(particle_generator=release_table.subset)
+
+    def update(self, model: "Model"):
         self._add_new(model)
         self._kill_old(model)
 
     # noinspection PyMethodMayBeStatic
-    def _kill_old(self, model: Model):
+    def _kill_old(self, model: "Model"):
         state = model.state
         if 'alive' in state:
             alive = state['alive']
             alive &= model.grid.ingrid(state['X'], state['Y'])
             state.remove(~alive)
 
-    def _add_new(self, model: Model):
+    def _add_new(self, model: "Model"):
         # Get the portion of the release dataset that corresponds to
         # current simulation time
-        df =
-
-
-
-        ).copy(deep=True)
+        df = self.particle_generator(
+            model.solver.time,
+            model.solver.time + model.solver.step,
+        )
 
         # If there are no new particles, but the state is empty, we should
         # still initialize the state by adding the appropriate columns
         if (len(df) == 0) and ('X' not in model.state):
             model.state.append(df.to_dict(orient='list'))
-        self._last_release_dataframe = df
 
-        # If there are no new particles
-
-        continuous_release = bool(self._frequency)
-        if (len(df) == 0) and not continuous_release:
-            return
-
-        # If we have continuous release, but there are no new particles and
-        # the last release is recent, we are also done
-        current_time = model.solver.time
-        elapsed_since_last_write = current_time - self._last_release_time
-        last_release_is_recent = (elapsed_since_last_write < self._frequency)
-        if continuous_release and (len(df) == 0) and last_release_is_recent:
+        # If there are no new particles, we are done.
+        if len(df) == 0:
             return
 
         # If we are at the final time step, we should not release any more particles
-        if
+        if model.solver.time >= model.solver.stop:
             return
 
-        # If we have continuous release, but there are no new particles and
-        # the last release is NOT recent, we should replace empty
-        # dataframe with the previously released dataframe
-        if continuous_release:
-            if (len(df) == 0) and not last_release_is_recent:
-                df = self._last_release_dataframe
-            self._last_release_dataframe = df  # Update release dataframe
-            self._last_release_time = current_time
-
-        # If positions are given as lat/lon coordinates, we should convert
-        if "X" not in df.columns or "Y" not in df.columns:
-            if "lon" not in df.columns or "lat" not in df.columns:
-                logger.critical("Particle release must have position")
-                raise ValueError()
-            # else
-            X, Y = model.grid.ll2xy(df["lon"].values, df["lat"].values)
-            df.rename(columns=dict(lon="X", lat="Y"), inplace=True)
-            df["X"] = X
-            df["Y"] = Y
-
-        # Add default variables, if any
-        for k, v in self._defaults.items():
-            if k not in df:
-                df[k] = v
-
-        # Expand multiplicity variable, if any
-        if 'mult' in df:
-            df = df.loc[np.repeat(df.index, df['mult'].values.astype('i4'))]
-            df = df.reset_index(drop=True).drop(columns='mult')
-
         # Add new particles
         new_particles = df.to_dict(orient='list')
         state = model.state
         state.append(new_particles)
 
-
-
-
-
-
-
-
-        with open(file_or_buf, *args, **kwargs) as f:
-            yield f
-
-        if self._dataframe is None:
-            if isinstance(self._csv_fname, pd.DataFrame):
-                self._dataframe = self._csv_fname
-
-            else:
-                # noinspection PyArgumentList
-                with open_or_relay(self._csv_fname, 'r', encoding='utf-8') as fp:
-                    self._dataframe = load_release_file(
-                        stream=fp,
-                        names=self._csv_column_names,
-                        formats=self._csv_column_formats,
-                    )
-        return self._dataframe
-
-
-def release_data_subset(dataframe, start_time, stop_time):
-    start_idx, stop_idx = sorted_interval(
-        dataframe['release_time'].values,
-        start_time,
-        stop_time,
+
+def release_data_subset(dataframe, start_time, stop_time, interval: typing.Any = 0):
+    events = resolve_schedule(
+        times=dataframe['release_time'].values,
+        interval=interval,
+        start_time=start_time,
+        stop_time=stop_time,
     )
-
+
+    return dataframe.iloc[events]
 
 
 def load_release_file(stream, names: list, formats: dict) -> pd.DataFrame:
@@ -188,25 +125,6 @@ def load_release_file(stream, names: list, formats: dict) -> pd.DataFrame:
     return df
 
 
-def sorted_interval(v, a, b):
-    """
-    Searches for an interval in a sorted array
-
-    Returns the start (inclusive) and stop (exclusive) indices of
-    elements in *v* that are greater than or equal to *a* and
-    less than *b*. In other words, returns *start* and *stop* such
-    that v[start:stop] == v[(v >= a) & (v < b)]
-
-    :param v: Sorted input array
-    :param a: Lower bound of array values (inclusive)
-    :param b: Upper bound of array values (exclusive)
-    :returns: A tuple (start, stop) defining the output interval
-    """
-    start = np.searchsorted(v, a, side='left')
-    stop = np.searchsorted(v, b, side='left')
-    return start, stop
-
-
 def get_converters(varnames: list, conf: dict) -> dict:
     """
     Given a list of varnames and config keywords, return a dict of converters
@@ -236,3 +154,223 @@
         converters[varname] = dtype_func
 
     return converters
+
+
+class ReleaseTable:
+    def __init__(
+            self,
+            dataframe: pd.DataFrame,
+            interval: float,
+            defaults: dict[str, typing.Any],
+            lonlat_converter: typing.Callable[[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]],
+    ):
+        self.dataframe = dataframe
+        self.interval = interval
+        self.defaults = defaults
+        self.lonlat_converter = lonlat_converter
+
+    def subset(self, start_time, stop_time):
+        events = resolve_schedule(
+            times=self.dataframe['release_time'].values,
+            interval=self.interval,
+            start_time=start_time,
+            stop_time=stop_time,
+        )
+
+        df = self.dataframe.iloc[events].copy(deep=True)
+        df = replace_lonlat_in_release_table(df, self.lonlat_converter)
+        df = add_default_variables_in_release_table(df, self.defaults)
+        df = expand_multiplicity_in_release_table(df)
+
+        return df
+
+    @staticmethod
+    def from_filename_or_stream(file, column_names, column_formats, interval, defaults, lonlat_converter):
+        with open_or_relay(file, 'r', encoding='utf-8') as fp:
+            return ReleaseTable.from_stream(
+                fp, column_names, column_formats, interval, defaults, lonlat_converter)
+
+    @staticmethod
+    def from_stream(fp, column_names, column_formats, interval, defaults, lonlat_converter):
+        df = load_release_file(stream=fp, names=column_names, formats=column_formats)
+        return ReleaseTable(df, interval, defaults, lonlat_converter)
+
+
+def replace_lonlat_in_release_table(df, lonlat_converter):
+    if "lon" not in df.columns or "lat" not in df.columns:
+        return df
+
+    X, Y = lonlat_converter(df["lon"].values, df["lat"].values)
+    df_new = df.drop(columns=['X', 'Y', 'lat', 'lon'], errors='ignore')
+    df_new["X"] = X
+    df_new["Y"] = Y
+    return df_new
+
+
+def add_default_variables_in_release_table(df, defaults):
+    df_new = df.copy()
+    for k, v in defaults.items():
+        if k not in df:
+            df_new[k] = v
+    return df_new
+
+
+def expand_multiplicity_in_release_table(df):
+    if 'mult' not in df:
+        return df
+    df = df.loc[np.repeat(df.index, df['mult'].values.astype('i4'))]
+    df = df.reset_index(drop=True).drop(columns='mult')
+    return df
+
+
+def resolve_schedule(times, interval, start_time, stop_time):
+    """
+    Convert decriptions of repeated events to actual event indices
+
+    The variable `times` specifies start time of scheduled events. Each event occurs
+    repeatedly (specified by `interval`) until there is a new scheduling time.
+    The function returns the index of all events occuring within the time span.
+
+    Example 1: times = [0, 0], interval = 2. These are 2 events (index [0, 1]),
+    occuring at times [0, 2, 4, 6, ...], starting at time = 0. The time interval
+    start_time = 0, stop_time = 6 will contain the event times 0, 2, 4. The
+    returned event indices are [0, 1, 0, 1, 0, 1].
+
+    Example 2: times = [0, 0, 3, 3, 3], interval = 2. The schedule starts with
+    2 events (index [0, 1]) occuring at time = 0. At time = 2, there are no new
+    scheduled events, and the previous events are repeated. At time = 3 there
+    are 3 new events scheduled (index [2, 3, 4]), which cancel the previous
+    events. The new events are repeated at times [3, 5, 7, ...]. The time
+    interval start_time = 0, stop_time = 7 contain the event times [0, 2, 3, 5].
+    The returned event indices are [0, 1, 0, 1, 2, 3, 4, 2, 3, 4].
+
+    :param times: Nondecreasing list of event times
+    :param interval: Maximum interval between scheduled times
+    :param start_time: Start time of schedule
+    :param stop_time: Stop time of schedule (not inclusive)
+    :return: Index of events in resolved schedule
+    """
+
+    sched = Schedule(times=np.asarray(times), events=np.arange(len(times)))
+    sched2 = sched.resolve(start_time, stop_time, interval)
+    return sched2.events
+
+
+class Schedule:
+    def __init__(self, times: np.ndarray, events: np.ndarray):
+        self.times = times.view()
+        self.events = events.view()
+        self.times.flags.writeable = False
+        self.events.flags.writeable = False
+
+    def valid(self):
+        return np.all(np.diff(self.times) >= 0)
+
+    def copy(self):
+        return Schedule(times=self.times.copy(), events=self.events.copy())
+
+    def append(self, other: "Schedule"):
+        return Schedule(
+            times=np.concatenate((self.times, other.times)),
+            events=np.concatenate((self.events, other.events)),
+        )
+
+    def extend_backwards_using_interval(self, time, interval):
+        min_time = self.times[0]
+        if min_time <= time:
+            return self
+
+        num_extensions = int(np.ceil((min_time - time) / interval))
+        new_time = min_time - num_extensions * interval
+        return self.extend_backwards(new_time)
+
+    def extend_backwards(self, new_minimum_time):
+        idx_to_be_copied = (self.times == self.times[0])
+        num_to_be_copied = np.count_nonzero(idx_to_be_copied)
+        extension = Schedule(
+            times=np.repeat(new_minimum_time, num_to_be_copied),
+            events=self.events[idx_to_be_copied],
+        )
+        return extension.append(self)
+
+    def trim_tail(self, stop_time):
+        num = np.sum(self.times < stop_time)
+        return Schedule(
+            times=self.times[:num],
+            events=self.events[:num],
+        )
+
+    def trim_head(self, start_time):
+        num = np.sum(self.times < start_time)
+        return Schedule(
+            times=self.times[num:],
+            events=self.events[num:],
+        )
+
+    def rightshift_closest_time_value(self, time, interval):
+        # If interval=0 is specified, this means there is nothing to right-shift
+        if interval <= 0:
+            return self
+
+        # Find largest value that is smaller than or equal to time
+        idx_target_time = sum(self.times <= time) - 1
+
+        # If no tabulated time values are smaller than the given time, there
+        # is nothing to right-shift
+        if idx_target_time == -1:
+            return self
+
+        # Compute new value to write
+        target_time = self.times[idx_target_time]
+        num_offsets = np.ceil((time - target_time) / interval)
+        new_target_time = target_time + num_offsets * interval
+
+        # Check if the new value is larger than the next value
+        if idx_target_time + 1 < len(self.times):  # If not, then there is no next value
+            next_time = self.times[idx_target_time + 1]
+            if new_target_time > next_time:
+                return self
+
+        # Change times
+        new_times = self.times.copy()
+        new_times[self.times == target_time] = new_target_time
+        return Schedule(times=new_times, events=self.events)
+
+    def expand(self, interval, stop):
+        # If there are no times, there should be no expansion
+        # Also, interval = 0 means no expansion
+        if (len(self.times) == 0) or (interval <= 0):
+            return self
+
+        t_unq, t_inv, t_cnt = np.unique(self.times, return_inverse=True, return_counts=True)
+        stop2 = np.maximum(stop, t_unq[-1])
+        diff = np.diff(np.concatenate((t_unq, [stop2])))
+        unq_repeats = np.ceil(diff / interval).astype(int)
+        repeats = np.repeat(unq_repeats, t_cnt)
+
+        base_times = np.repeat(self.times, repeats)
+        offsets = [i * interval for n in repeats for i in range(n)]
+        times = base_times + offsets
+        events = np.repeat(self.events, repeats)
+
+        idx = np.lexsort((events, times))
+
+        return Schedule(times=times[idx], events=events[idx])
+
+    def resolve(self, start, stop, interval):
+        s = self
+        if interval:
+            s = s.rightshift_closest_time_value(start, interval)
+        s = s.trim_head(start)
+        s = s.trim_tail(stop)
+        s = s.expand(interval, stop)
+        return s
+
+
+@contextlib.contextmanager
+def open_or_relay(file_or_buf, *args, **kwargs):
+    if hasattr(file_or_buf, 'read'):
+        yield file_or_buf
+    else:
+        with open(file_or_buf, *args, **kwargs) as f:
+            yield f
```
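The heart of the new release logic is `resolve_schedule` together with the `Schedule` class, which turn scheduled release times plus a repeat interval into concrete event indices. The snippet below reproduces Example 2 from the docstring above; it assumes ladim 2.1.x is importable:

```python
from ladim.release import resolve_schedule

# Two events (indices 0, 1) scheduled at t=0 and three events (indices 2, 3, 4)
# scheduled at t=3, each batch repeating every 2 time units until superseded.
events = resolve_schedule(
    times=[0, 0, 3, 3, 3],
    interval=2,
    start_time=0,
    stop_time=7,
)

# Resolved event times within [0, 7) are 0, 2, 3 and 5, so each batch
# appears twice, exactly as stated in the docstring:
print(events.tolist())  # [0, 1, 0, 1, 2, 3, 4, 2, 3, 4]
```

This replaces the removed stateful bookkeeping (`_last_release_dataframe`, `_last_release_time`) in `_add_new` with a pure function of the release table and the query window.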
ladim/solver.py
CHANGED
```diff
@@ -1,9 +1,12 @@
 import numpy as np
 
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from ladim.model import Model
+
 
 class Solver:
-    def __init__(self, start, stop, step,
-        self.order = order or ('release', 'forcing', 'tracker', 'ibm', 'output')
+    def __init__(self, start, stop, step, seed=None):
         self.start = np.datetime64(start, 's').astype('int64')
         self.stop = np.datetime64(stop, 's').astype('int64')
         self.step = np.timedelta64(step, 's').astype('int64')
@@ -12,12 +15,13 @@ class Solver:
         if seed is not None:
             np.random.seed(seed)
 
-    def run(self, model):
-        modules = model.modules
-        ordered_modules = [modules[k] for k in self.order if k in modules]
-
+    def run(self, model: "Model"):
         self.time = self.start
         while self.time <= self.stop:
-
-
+            model.release.update(model)
+            model.forcing.update(model)
+            model.output.update(model)
+            model.tracker.update(model)
+            model.ibm.update(model)
+
             self.time += self.step
```
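The run loop no longer consults a configurable `order` tuple; each step now calls the module update hooks in a fixed sequence. Note that the hard-coded sequence also reorders the hooks relative to the old default `('release', 'forcing', 'tracker', 'ibm', 'output')`: output now runs before tracker and ibm. A toy sketch of one loop iteration with recording stubs (only the attribute names come from the diff; everything else is illustrative):

```python
calls = []

class Stub:
    """Stand-in for a ladim module; records when its update hook fires."""
    def __init__(self, name):
        self.name = name

    def update(self, model):
        calls.append(self.name)

class ToyModel:
    """Bare attribute bag mirroring the attributes the new run loop touches."""
    def __init__(self):
        self.release = Stub('release')
        self.forcing = Stub('forcing')
        self.output = Stub('output')
        self.tracker = Stub('tracker')
        self.ibm = Stub('ibm')

model = ToyModel()

# Body of one iteration of the 2.1.x Solver.run loop:
model.release.update(model)
model.forcing.update(model)
model.output.update(model)
model.tracker.update(model)
model.ibm.update(model)

print(calls)  # ['release', 'forcing', 'output', 'tracker', 'ibm']
```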