ladim 2.0.9__tar.gz → 2.1.5__tar.gz

Files changed (50)
  1. {ladim-2.0.9 → ladim-2.1.5}/PKG-INFO +3 -2
  2. {ladim-2.0.9 → ladim-2.1.5}/ladim/__init__.py +1 -1
  3. {ladim-2.0.9 → ladim-2.1.5}/ladim/config.py +1 -3
  4. {ladim-2.0.9 → ladim-2.1.5}/ladim/forcing.py +14 -9
  5. {ladim-2.0.9 → ladim-2.1.5}/ladim/grid.py +6 -3
  6. {ladim-2.0.9 → ladim-2.1.5}/ladim/gridforce/ROMS.py +1 -1
  7. ladim-2.1.5/ladim/ibms/__init__.py +38 -0
  8. ladim-2.1.5/ladim/model.py +79 -0
  9. {ladim-2.0.9 → ladim-2.1.5}/ladim/output.py +5 -7
  10. ladim-2.1.5/ladim/release.py +376 -0
  11. {ladim-2.0.9 → ladim-2.1.5}/ladim/solver.py +12 -8
  12. {ladim-2.0.9 → ladim-2.1.5}/ladim/state.py +20 -46
  13. {ladim-2.0.9 → ladim-2.1.5}/ladim/tracker.py +13 -11
  14. {ladim-2.0.9 → ladim-2.1.5}/ladim/utilities.py +28 -0
  15. {ladim-2.0.9 → ladim-2.1.5}/ladim.egg-info/PKG-INFO +3 -2
  16. {ladim-2.0.9 → ladim-2.1.5}/ladim.egg-info/SOURCES.txt +2 -2
  17. {ladim-2.0.9 → ladim-2.1.5}/tests/test_config.py +4 -2
  18. {ladim-2.0.9 → ladim-2.1.5}/tests/test_output.py +5 -5
  19. ladim-2.1.5/tests/test_release.py +313 -0
  20. ladim-2.1.5/tests/test_solver.py +0 -0
  21. ladim-2.0.9/tests/test_model.py → ladim-2.1.5/tests/test_utilities.py +3 -4
  22. ladim-2.0.9/ladim/ibms/__init__.py +0 -26
  23. ladim-2.0.9/ladim/model.py +0 -144
  24. ladim-2.0.9/ladim/release.py +0 -238
  25. ladim-2.0.9/tests/test_release.py +0 -161
  26. ladim-2.0.9/tests/test_solver.py +0 -60
  27. {ladim-2.0.9 → ladim-2.1.5}/LICENSE +0 -0
  28. {ladim-2.0.9 → ladim-2.1.5}/README.md +0 -0
  29. {ladim-2.0.9 → ladim-2.1.5}/ladim/__main__.py +0 -0
  30. {ladim-2.0.9 → ladim-2.1.5}/ladim/gridforce/__init__.py +0 -0
  31. {ladim-2.0.9 → ladim-2.1.5}/ladim/gridforce/analytical.py +0 -0
  32. {ladim-2.0.9 → ladim-2.1.5}/ladim/gridforce/zROMS.py +0 -0
  33. {ladim-2.0.9 → ladim-2.1.5}/ladim/ibms/light.py +0 -0
  34. {ladim-2.0.9 → ladim-2.1.5}/ladim/main.py +0 -0
  35. {ladim-2.0.9 → ladim-2.1.5}/ladim/plugins/__init__.py +0 -0
  36. {ladim-2.0.9 → ladim-2.1.5}/ladim/sample.py +0 -0
  37. {ladim-2.0.9 → ladim-2.1.5}/ladim.egg-info/dependency_links.txt +0 -0
  38. {ladim-2.0.9 → ladim-2.1.5}/ladim.egg-info/entry_points.txt +0 -0
  39. {ladim-2.0.9 → ladim-2.1.5}/ladim.egg-info/requires.txt +0 -0
  40. {ladim-2.0.9 → ladim-2.1.5}/ladim.egg-info/top_level.txt +0 -0
  41. {ladim-2.0.9 → ladim-2.1.5}/postladim/__init__.py +0 -0
  42. {ladim-2.0.9 → ladim-2.1.5}/postladim/cellcount.py +0 -0
  43. {ladim-2.0.9 → ladim-2.1.5}/postladim/kde_plot.py +0 -0
  44. {ladim-2.0.9 → ladim-2.1.5}/postladim/particlefile.py +0 -0
  45. {ladim-2.0.9 → ladim-2.1.5}/postladim/variable.py +0 -0
  46. {ladim-2.0.9 → ladim-2.1.5}/pyproject.toml +0 -0
  47. {ladim-2.0.9 → ladim-2.1.5}/setup.cfg +0 -0
  48. {ladim-2.0.9 → ladim-2.1.5}/tests/test_forcing.py +0 -0
  49. {ladim-2.0.9 → ladim-2.1.5}/tests/test_grid.py +0 -0
  50. {ladim-2.0.9 → ladim-2.1.5}/tests/test_ladim.py +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: ladim
- Version: 2.0.9
+ Version: 2.1.5
  Summary: Lagrangian Advection and Diffusion Model
  Home-page: https://github.com/pnsaevik/ladim
  Author: Bjørn Ådlandsvik
@@ -25,6 +25,7 @@ Requires-Dist: pyproj
  Requires-Dist: pyyaml
  Requires-Dist: scipy
  Requires-Dist: xarray
+ Dynamic: license-file

  LADiM – the Lagrangian Advection and Diffusion Model
  ====================================================
@@ -1,3 +1,3 @@
- __version__ = '2.0.9'
+ __version__ = '2.1.5'

  from .main import main, run
@@ -82,7 +82,6 @@ def convert_1_to_2(c):
      out['solver']['stop'] = dict_get(c, 'time_control.stop_time')
      out['solver']['step'] = dt_sec
      out['solver']['seed'] = dict_get(c, 'numerics.seed')
-     out['solver']['order'] = ['release', 'forcing', 'output', 'tracker', 'ibm', 'state']

      out['grid'] = {}
      out['grid']['file'] = dict_get(c, [
@@ -93,7 +92,7 @@ def convert_1_to_2(c):
      out['grid']['start_time'] = np.datetime64(dict_get(c, 'time_control.start_time', '1970'), 's')
      out['grid']['subgrid'] = dict_get(c, 'gridforce.subgrid', None)

-     out['forcing'] = {}
+     out['forcing'] = {k: v for k, v in c.get('gridforce', {}).items() if k not in ('input_file', 'module')}
      out['forcing']['file'] = dict_get(c, ['gridforce.input_file', 'files.input_file'])
      out['forcing']['first_file'] = dict_get(c, 'gridforce.first_file', "")
      out['forcing']['last_file'] = dict_get(c, 'gridforce.last_file', "")
@@ -142,7 +141,6 @@ def convert_1_to_2(c):

      out['ibm'] = {}
      if 'ibm' in c:
-         out['ibm']['module'] = 'ladim.ibms.LegacyIBM'
          out['ibm']['legacy_module'] = dict_get(c, ['ibm.ibm_module', 'ibm.module'])
          if out['ibm']['legacy_module'] == 'ladim.ibms.ibm_salmon_lice':
              out['ibm']['legacy_module'] = 'ladim_plugins.salmon_lice'
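With the rewritten `out['forcing']` line, any extra keys under the legacy `gridforce` section (other than `input_file` and `module`) are now carried over into the new-style forcing configuration instead of being dropped. A minimal sketch of the effect; the key names and values below are made up for illustration:

```python
# Illustrative legacy 'gridforce' section (not taken from a real configuration)
legacy_gridforce = {
    'module': 'ladim.gridforce.ROMS',   # excluded: handled elsewhere
    'input_file': 'ocean_avg_*.nc',     # excluded: re-added as out['forcing']['file']
    'ibm_forcing': ['temp', 'salt'],    # any other key is forwarded unchanged
}

forcing = {k: v for k, v in legacy_gridforce.items()
           if k not in ('input_file', 'module')}
print(forcing)  # {'ibm_forcing': ['temp', 'salt']}
```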
@@ -1,10 +1,19 @@
- from .model import Model, Module
+ import typing
+ if typing.TYPE_CHECKING:
+     from ladim.model import Model


- class Forcing(Module):
+ class Forcing:
+     @staticmethod
+     def from_roms(**conf):
+         return RomsForcing(**conf)
+
      def velocity(self, X, Y, Z, tstep=0.0):
          raise NotImplementedError

+     def update(self, model: "Model"):
+         raise NotImplementedError
+

  class RomsForcing(Forcing):
      def __init__(self, file, variables=None, **conf):
@@ -37,11 +46,7 @@ class RomsForcing(Forcing):

          grid_ref = GridReference()
          legacy_conf = dict(
-             gridforce=dict(
-                 input_file=file,
-                 first_file=conf.get('first_file', ""),
-                 last_file=conf.get('last_file', ""),
-             ),
+             gridforce=dict(input_file=file, **conf),
              ibm_forcing=conf.get('ibm_forcing', []),
              start_time=conf.get('start_time', None),
              stop_time=conf.get('stop_time', None),
@@ -50,7 +55,7 @@ class RomsForcing(Forcing):
          if conf.get('subgrid', None) is not None:
              legacy_conf['gridforce']['subgrid'] = conf['subgrid']

-         from .model import load_class
+         from .utilities import load_class
          LegacyForcing = load_class(conf.get('legacy_module', 'ladim.gridforce.ROMS.Forcing'))

          # Allow gridforce module in current directory
@@ -63,7 +68,7 @@ class RomsForcing(Forcing):
          # self.U = self.forcing.U
          # self.V = self.forcing.V

-     def update(self, model: Model):
+     def update(self, model: "Model"):
          elapsed = model.solver.time - model.solver.start
          t = elapsed // model.solver.step

@@ -1,16 +1,19 @@
- from .model import Module
  import numpy as np
  from typing import Sequence
  from scipy.ndimage import map_coordinates


- class Grid(Module):
+ class Grid:
      """
      The grid class represents the coordinate system used for particle tracking.
      It contains methods for converting between global coordinates (latitude,
      longitude, depth and posix time) and internal coordinates.
      """

+     @staticmethod
+     def from_roms(**conf):
+         return RomsGrid(**conf)
+
      def ingrid(self, X, Y):
          raise NotImplementedError

@@ -188,7 +191,7 @@ class RomsGrid(Grid):
          if subgrid is not None:
              legacy_conf['gridforce']['subgrid'] = subgrid

-         from .model import load_class
+         from .utilities import load_class
          LegacyGrid = load_class(legacy_module)

          # Allow gridforce module in current directory
@@ -62,7 +62,7 @@ class Grid:
          # Here, imax, jmax refers to whole grid
          jmax, imax = ncid.variables["h"].shape
          whole_grid = [1, imax - 1, 1, jmax - 1]
-         if "subgrid" in config["gridforce"]:
+         if config["gridforce"].get('subgrid', None):
              limits = list(config["gridforce"]["subgrid"])
          else:
              limits = whole_grid
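The changed condition is more than a style fix: with `.get('subgrid', None)`, a `subgrid` key that is present but set to `None` (or another falsy value) now falls back to the whole grid, whereas the old membership test would go on to call `list(None)` and fail. A small sketch with a stand-in config dict:

```python
# Stand-in config dict; only the 'subgrid' handling is of interest here
config = {"gridforce": {"subgrid": None}}

# Old check: key presence only -> True, and list(None) would then raise TypeError
"subgrid" in config["gridforce"]            # True

# New check: the value must be truthy -> falls back to the whole grid
bool(config["gridforce"].get("subgrid"))    # False
```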
@@ -0,0 +1,38 @@
+ import numpy as np
+ import typing
+
+ if typing.TYPE_CHECKING:
+     from ..model import Model
+
+
+ class IBM:
+     def __init__(self, legacy_module=None, conf: dict = None):
+         from ..utilities import load_class
+
+         if legacy_module is None:
+             UserIbmClass = EmptyIBM
+         else:
+             UserIbmClass = load_class(legacy_module + '.IBM')
+
+         self.user_ibm = UserIbmClass(conf or {})
+
+     def update(self, model: "Model"):
+         grid = model.grid
+         state = model.state
+
+         state.dt = model.solver.step
+         state.timestamp = np.int64(model.solver.time).astype('datetime64[s]')
+         state.timestep = (
+             (model.solver.time - model.solver.start) // model.solver.step
+         )
+
+         forcing = model.forcing
+         self.user_ibm.update_ibm(grid, state, forcing)
+
+
+ class EmptyIBM:
+     def __init__(self, _):
+         pass
+
+     def update_ibm(self, grid, state, forcing):
+         return
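The new `IBM` wrapper imports `<legacy_module>.IBM`, stamps `dt`, `timestamp` and `timestep` onto the state, and then delegates to the legacy class's `update_ibm(grid, state, forcing)`. A hypothetical minimal legacy-style module that would plug into this wrapper; the module name, config key and aging logic are illustrative only, and it assumes the state carries `age` and `alive` columns:

```python
# my_ibm.py -- hypothetical user module, referenced as legacy_module='my_ibm'

class IBM:
    def __init__(self, config):
        # 'config' is the conf dict passed to ladim.ibms.IBM
        self.max_age = config.get('max_age', 86400)  # seconds

    def update_ibm(self, grid, state, forcing):
        # The wrapper sets state.dt to the solver step (in seconds) before each call
        state['age'] += state.dt
        state['alive'] &= state['age'] < self.max_age
```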
@@ -0,0 +1,79 @@
+ from ladim.ibms import IBM
+ from ladim.solver import Solver
+ from ladim.release import Releaser
+ from ladim.grid import Grid
+ from ladim.forcing import Forcing
+ from ladim.state import State
+ from ladim.tracker import Tracker
+ from ladim.output import Output
+
+
+ class Model:
+     """
+     The Model class represents the entire simulation model. The different
+     submodules control the simulation behaviour. In particular, the solver
+     submodule controls the execution flow while the other submodules are
+     called once every time step within the main simulation loop.
+     """
+
+     def __init__(
+             self, grid: "Grid", forcing: "Forcing", release: "Releaser",
+             state: "State", output: "Output", ibm: "IBM", tracker: "Tracker",
+             solver: "Solver",
+     ):
+         self.grid = grid
+         self.forcing = forcing
+         self.release = release
+         self.state = state
+         self.output = output
+         self.ibm = ibm
+         self.tracker = tracker
+         self.solver = solver
+
+     @staticmethod
+     def from_config(config: dict) -> "Model":
+         """
+         Initialize a model class by supplying the configuration parameters
+         of each submodule.
+
+         :param config: Configuration parameters for each submodule
+         :return: An initialized Model class
+         """
+
+         grid = Grid.from_roms(**config['grid'])
+         forcing = Forcing.from_roms(**config['forcing'])
+
+         release = Releaser.from_textfile(
+             lonlat_converter=grid.ll2xy, **config['release']
+         )
+         tracker = Tracker.from_config(**config['tracker'])
+
+         output = Output(**config['output'])
+         ibm = IBM(**config['ibm'])
+         solver = Solver(**config['solver'])
+
+         state = State()
+
+         # noinspection PyTypeChecker
+         return Model(grid, forcing, release, state, output, ibm, tracker, solver)
+
+     @property
+     def modules(self) -> dict:
+         return dict(
+             grid=self.grid,
+             forcing=self.forcing,
+             release=self.release,
+             state=self.state,
+             output=self.output,
+             ibm=self.ibm,
+             tracker=self.tracker,
+             solver=self.solver,
+         )
+
+     def run(self):
+         self.solver.run(self)
+
+     def close(self):
+         for m in self.modules.values():
+             if hasattr(m, 'close') and callable(m.close):
+                 m.close()
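A minimal usage sketch of the new top-level class, assuming a config dict with one section per submodule; the individual keys shown are illustrative and not a complete, verified configuration:

```python
from ladim.model import Model

config = {
    'grid':    {'file': 'ocean_avg_0014.nc'},
    'forcing': {'file': 'ocean_avg_0014.nc'},
    'release': {'file': 'particles.rls', 'frequency': [1, 'h']},
    'tracker': {'advection': 'RK4'},
    'output':  {'file': 'out.nc', 'frequency': [1, 'h'], 'variables': {}},
    'ibm':     {},
    'solver':  {'start': '2015-09-07', 'stop': '2015-09-08', 'step': 600},
}

model = Model.from_config(config)
model.run()    # Solver.run(model) drives the main simulation loop
model.close()  # closes every submodule that defines a close() method
```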
@@ -1,13 +1,11 @@
- from .model import Model, Module
  import netCDF4 as nc
  import numpy as np
+ import typing
+ if typing.TYPE_CHECKING:
+     from .model import Model


- class Output(Module):
-     pass
-
-
- class RaggedOutput(Output):
+ class Output:
      def __init__(self, variables: dict, file: str, frequency):
          """
          Writes simulation output to netCDF file in ragged array format
@@ -52,7 +50,7 @@ class RaggedOutput(Output):
          """Returns a handle to the netCDF dataset currently being written to"""
          return self._dset

-     def update(self, model: Model):
+     def update(self, model: "Model"):
          if self._dset is None:
              self._create_dset()

@@ -0,0 +1,376 @@
+ import contextlib
+ import numpy as np
+ import pandas as pd
+ from .utilities import read_timedelta
+ import logging
+ import typing
+
+ if typing.TYPE_CHECKING:
+     from ladim.model import Model
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Releaser:
+     def __init__(self, particle_generator: typing.Callable[[float, float], pd.DataFrame]):
+         self.particle_generator = particle_generator
+
+     @staticmethod
+     def from_textfile(
+             file, colnames: list = None, formats: dict = None,
+             frequency=(0, 's'), defaults=None, lonlat_converter=None,
+     ):
+         """
+         Release module which reads from a text file
+
+         The text file must be a whitespace-separated csv file
+
+         :param lonlat_converter: Function that converts lon, lat coordinates to
+             x, y coordinates
+
+         :param file: Release file
+
+         :param colnames: Column names, if the release file does not contain any
+
+         :param formats: Data column formats, one dict entry per column. If any column
+             is missing, the default format is used. Keys should correspond to column names.
+             Values should be either ``"float"``, ``"int"`` or ``"time"``. Default value
+             is ``"float"`` for all columns except ``release_time``, which has default
+             value ``"time"``.
+
+         :param frequency: A two-element list with entries ``[value, unit]``, where
+             ``unit`` can be any numpy-compatible timedelta unit (such as "s", "m", "h", "D").
+
+         :param defaults: A dict of variables to be added to each particle. The keys
+             are the variable names, the values are the initial values at particle
+             release.
+         """
+
+         release_table = ReleaseTable.from_filename_or_stream(
+             file=file,
+             column_names=colnames,
+             column_formats=formats or dict(),
+             interval=read_timedelta(frequency) / np.timedelta64(1, 's'),
+             defaults=defaults or dict(),
+             lonlat_converter=lonlat_converter,
+         )
+         return Releaser(particle_generator=release_table.subset)
+
+     def update(self, model: "Model"):
+         self._add_new(model)
+         self._kill_old(model)
+
+     # noinspection PyMethodMayBeStatic
+     def _kill_old(self, model: "Model"):
+         state = model.state
+         if 'alive' in state:
+             alive = state['alive']
+             alive &= model.grid.ingrid(state['X'], state['Y'])
+             state.remove(~alive)
+
+     def _add_new(self, model: "Model"):
+         # Get the portion of the release dataset that corresponds to
+         # current simulation time
+         df = self.particle_generator(
+             model.solver.time,
+             model.solver.time + model.solver.step,
+         )
+
+         # If there are no new particles, but the state is empty, we should
+         # still initialize the state by adding the appropriate columns
+         if (len(df) == 0) and ('X' not in model.state):
+             model.state.append(df.to_dict(orient='list'))
+
+         # If there are no new particles, we are done.
+         if len(df) == 0:
+             return
+
+         # If we are at the final time step, we should not release any more particles
+         if model.solver.time >= model.solver.stop:
+             return
+
+         # Add new particles
+         new_particles = df.to_dict(orient='list')
+         state = model.state
+         state.append(new_particles)
+
+
+ def release_data_subset(dataframe, start_time, stop_time, interval: typing.Any = 0):
+     events = resolve_schedule(
+         times=dataframe['release_time'].values,
+         interval=interval,
+         start_time=start_time,
+         stop_time=stop_time,
+     )
+
+     return dataframe.iloc[events]
+
+
+ def load_release_file(stream, names: list, formats: dict) -> pd.DataFrame:
+     if names is None:
+         import re
+         first_line = stream.readline()
+         names = re.split(pattern=r'\s+', string=first_line.strip())
+
+     converters = get_converters(varnames=names, conf=formats)
+
+     df = pd.read_csv(
+         stream,
+         names=names,
+         converters=converters,
+         sep='\\s+',
+     )
+     df = df.sort_values(by='release_time')
+     return df
+
+
+ def get_converters(varnames: list, conf: dict) -> dict:
+     """
+     Given a list of varnames and config keywords, return a dict of converters
+
+     Returns a dict where the keys are ``varnames`` and the values are
+     callables.
+
+     :param varnames: For instance, ['release_time', 'X', 'Y']
+     :param conf: For instance, {'release_time': 'time', 'X': 'float'}
+     :return: A mapping of varnames to converters
+     """
+     dtype_funcs = dict(
+         time=lambda item: np.datetime64(item, 's').astype('int64'),
+         int=int,
+         float=float,
+     )
+
+     dtype_defaults = dict(
+         release_time='time',
+     )
+
+     converters = {}
+     for varname in varnames:
+         dtype_default = dtype_defaults.get(varname, 'float')
+         dtype_str = conf.get(varname, dtype_default)
+         dtype_func = dtype_funcs[dtype_str]
+         converters[varname] = dtype_func
+
+     return converters
+
+
+ class ReleaseTable:
+     def __init__(
+             self,
+             dataframe: pd.DataFrame,
+             interval: float,
+             defaults: dict[str, typing.Any],
+             lonlat_converter: typing.Callable[[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]],
+     ):
+         self.dataframe = dataframe
+         self.interval = interval
+         self.defaults = defaults
+         self.lonlat_converter = lonlat_converter
+
+     def subset(self, start_time, stop_time):
+         events = resolve_schedule(
+             times=self.dataframe['release_time'].values,
+             interval=self.interval,
+             start_time=start_time,
+             stop_time=stop_time,
+         )
+
+         df = self.dataframe.iloc[events].copy(deep=True)
+         df = replace_lonlat_in_release_table(df, self.lonlat_converter)
+         df = add_default_variables_in_release_table(df, self.defaults)
+         df = expand_multiplicity_in_release_table(df)
+
+         return df
+
+     @staticmethod
+     def from_filename_or_stream(file, column_names, column_formats, interval, defaults, lonlat_converter):
+         with open_or_relay(file, 'r', encoding='utf-8') as fp:
+             return ReleaseTable.from_stream(
+                 fp, column_names, column_formats, interval, defaults, lonlat_converter)
+
+     @staticmethod
+     def from_stream(fp, column_names, column_formats, interval, defaults, lonlat_converter):
+         df = load_release_file(stream=fp, names=column_names, formats=column_formats)
+         return ReleaseTable(df, interval, defaults, lonlat_converter)
+
+
+ def replace_lonlat_in_release_table(df, lonlat_converter):
+     if "lon" not in df.columns or "lat" not in df.columns:
+         return df
+
+     X, Y = lonlat_converter(df["lon"].values, df["lat"].values)
+     df_new = df.drop(columns=['X', 'Y', 'lat', 'lon'], errors='ignore')
+     df_new["X"] = X
+     df_new["Y"] = Y
+     return df_new
+
+
+ def add_default_variables_in_release_table(df, defaults):
+     df_new = df.copy()
+     for k, v in defaults.items():
+         if k not in df:
+             df_new[k] = v
+     return df_new
+
+
+ def expand_multiplicity_in_release_table(df):
+     if 'mult' not in df:
+         return df
+     df = df.loc[np.repeat(df.index, df['mult'].values.astype('i4'))]
+     df = df.reset_index(drop=True).drop(columns='mult')
+     return df
+
+
+ def resolve_schedule(times, interval, start_time, stop_time):
+     """
+     Convert descriptions of repeated events to actual event indices
+
+     The variable `times` specifies start time of scheduled events. Each event occurs
+     repeatedly (specified by `interval`) until there is a new scheduling time.
+     The function returns the index of all events occurring within the time span.
+
+     Example 1: times = [0, 0], interval = 2. These are 2 events (index [0, 1]),
+     occurring at times [0, 2, 4, 6, ...], starting at time = 0. The time interval
+     start_time = 0, stop_time = 6 will contain the event times 0, 2, 4. The
+     returned event indices are [0, 1, 0, 1, 0, 1].
+
+     Example 2: times = [0, 0, 3, 3, 3], interval = 2. The schedule starts with
+     2 events (index [0, 1]) occurring at time = 0. At time = 2, there are no new
+     scheduled events, and the previous events are repeated. At time = 3 there
+     are 3 new events scheduled (index [2, 3, 4]), which cancel the previous
+     events. The new events are repeated at times [3, 5, 7, ...]. The time
+     interval start_time = 0, stop_time = 7 contains the event times [0, 2, 3, 5].
+     The returned event indices are [0, 1, 0, 1, 2, 3, 4, 2, 3, 4].
+
+     :param times: Nondecreasing list of event times
+     :param interval: Maximum interval between scheduled times
+     :param start_time: Start time of schedule
+     :param stop_time: Stop time of schedule (not inclusive)
+     :return: Index of events in resolved schedule
+     """
+
+     sched = Schedule(times=np.asarray(times), events=np.arange(len(times)))
+     sched2 = sched.resolve(start_time, stop_time, interval)
+     return sched2.events
+
+
+ class Schedule:
+     def __init__(self, times: np.ndarray, events: np.ndarray):
+         self.times = times.view()
+         self.events = events.view()
+         self.times.flags.writeable = False
+         self.events.flags.writeable = False
+
+     def valid(self):
+         return np.all(np.diff(self.times) >= 0)
+
+     def copy(self):
+         return Schedule(times=self.times.copy(), events=self.events.copy())
+
+     def append(self, other: "Schedule"):
+         return Schedule(
+             times=np.concatenate((self.times, other.times)),
+             events=np.concatenate((self.events, other.events)),
+         )
+
+     def extend_backwards_using_interval(self, time, interval):
+         min_time = self.times[0]
+         if min_time <= time:
+             return self
+
+         num_extensions = int(np.ceil((min_time - time) / interval))
+         new_time = min_time - num_extensions * interval
+         return self.extend_backwards(new_time)
+
+     def extend_backwards(self, new_minimum_time):
+         idx_to_be_copied = (self.times == self.times[0])
+         num_to_be_copied = np.count_nonzero(idx_to_be_copied)
+         extension = Schedule(
+             times=np.repeat(new_minimum_time, num_to_be_copied),
+             events=self.events[idx_to_be_copied],
+         )
+         return extension.append(self)
+
+     def trim_tail(self, stop_time):
+         num = np.sum(self.times < stop_time)
+         return Schedule(
+             times=self.times[:num],
+             events=self.events[:num],
+         )
+
+     def trim_head(self, start_time):
+         num = np.sum(self.times < start_time)
+         return Schedule(
+             times=self.times[num:],
+             events=self.events[num:],
+         )
+
+     def rightshift_closest_time_value(self, time, interval):
+         # If interval=0 is specified, this means there is nothing to right-shift
+         if interval <= 0:
+             return self
+
+         # Find largest value that is smaller than or equal to time
+         idx_target_time = sum(self.times <= time) - 1
+
+         # If no tabulated time values are smaller than the given time, there
+         # is nothing to right-shift
+         if idx_target_time == -1:
+             return self
+
+         # Compute new value to write
+         target_time = self.times[idx_target_time]
+         num_offsets = np.ceil((time - target_time) / interval)
+         new_target_time = target_time + num_offsets * interval
+
+         # Check if the new value is larger than the next value
+         if idx_target_time + 1 < len(self.times):  # If not, then there is no next value
+             next_time = self.times[idx_target_time + 1]
+             if new_target_time > next_time:
+                 return self
+
+         # Change times
+         new_times = self.times.copy()
+         new_times[self.times == target_time] = new_target_time
+         return Schedule(times=new_times, events=self.events)
+
+     def expand(self, interval, stop):
+         # If there are no times, there should be no expansion
+         # Also, interval = 0 means no expansion
+         if (len(self.times) == 0) or (interval <= 0):
+             return self
+
+         t_unq, t_inv, t_cnt = np.unique(self.times, return_inverse=True, return_counts=True)
+         stop2 = np.maximum(stop, t_unq[-1])
+         diff = np.diff(np.concatenate((t_unq, [stop2])))
+         unq_repeats = np.ceil(diff / interval).astype(int)
+         repeats = np.repeat(unq_repeats, t_cnt)
+
+         base_times = np.repeat(self.times, repeats)
+         offsets = [i * interval for n in repeats for i in range(n)]
+         times = base_times + offsets
+         events = np.repeat(self.events, repeats)
+
+         idx = np.lexsort((events, times))
+
+         return Schedule(times=times[idx], events=events[idx])
+
+     def resolve(self, start, stop, interval):
+         s = self
+         if interval:
+             s = s.rightshift_closest_time_value(start, interval)
+         s = s.trim_head(start)
+         s = s.trim_tail(stop)
+         s = s.expand(interval, stop)
+         return s
+
+
+ @contextlib.contextmanager
+ def open_or_relay(file_or_buf, *args, **kwargs):
+     if hasattr(file_or_buf, 'read'):
+         yield file_or_buf
+     else:
+         with open(file_or_buf, *args, **kwargs) as f:
+             yield f
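To make the schedule semantics concrete, Example 1 from the `resolve_schedule` docstring can be run directly; the printed indices match what the docstring promises:

```python
from ladim.release import resolve_schedule

# Two events scheduled at time 0, repeated every 2 time units within [0, 6)
idx = resolve_schedule(times=[0, 0], interval=2, start_time=0, stop_time=6)
print(list(idx))  # [0, 1, 0, 1, 0, 1]  -> event times 0, 0, 2, 2, 4, 4
```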