Simple-Track 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
simpletrack/load.py ADDED
@@ -0,0 +1,170 @@
1
+ import datetime as dt
2
+ from typing import Union
3
+
4
+ from numpy.typing import NDArray
5
+
6
+ from simpletrack.utils import check_arrays
7
+
8
+
9
class ConfigError(Exception):
    """Raised when one or more config input parameters are not valid."""
13
+
14
+
15
def get_loader(loader_key: str):
    """
    Look up a built-in loader class by name.

    Args:
        loader_key (str):
            Name of the loader, one of "MWELoader" or "ChilboltonLoader"

    Returns:
        type:
            The loader class (not an instance)

    Raises:
        KeyError: if loader_key does not name a known loader
        TypeError: if the registered loader does not derive from BaseLoader
    """
    registry = {"MWELoader": MWELoader, "ChilboltonLoader": ChilboltonLoader}
    try:
        loader_cls = registry[loader_key]
    except KeyError as lookup_error:
        raise KeyError(f"Unknown loader: {loader_key}") from lookup_error

    if issubclass(loader_cls, BaseLoader):
        return loader_cls
    raise TypeError(f"Requested loader ({loader_cls}) is not type BaseLoader")
27
+
28
+
29
class BaseLoader:
    """
    Base class for building custom loaders for use with Simple-Track.

    To use, inherit from this class and implement the `user_definable_load` method,
    which takes a single input (filename) and should return a (datetime, array)
    pair. The loader should be initialised with a list of filenames, which will be
    iterated through when the loader is used in Simple-Track.
    Loaded data is checked for consistency and type before being passed to
    Simple-Track, so the user only needs to worry about loading the data in the
    correct format.
    """

    def __init__(self, input_data: Union[list[str], tuple[str, ...]]) -> None:
        """
        Args:
            input_data (list|tuple):
                Filenames to load, in the order they should be iterated

        Raises:
            TypeError: if input_data is not a list or tuple
        """
        # Validate before storing anything so a failed construction leaves no
        # half-initialised state behind
        if not isinstance(input_data, (list, tuple)):
            raise TypeError(f"Expected input_data type list, got {type(input_data)}")
        # Shape of the first loaded array; every later array must match it
        self.domain_shape = None
        self.input_data = input_data
        # Initialised here so next() also works without a prior iter() call
        self.iter_idx = 0

    def __iter__(self):
        # Restart from the first entry each time iteration begins
        self.iter_idx = 0
        return self

    def __next__(self) -> tuple[dt.datetime, NDArray]:
        """
        Load the next entry of input_data, check it and return (time, data).
        """
        if self.iter_idx >= len(self.input_data):
            raise StopIteration
        next_fnm = self.input_data[self.iter_idx]
        self.iter_idx += 1
        time, data = self.user_definable_load(next_fnm)
        self._check_loaded_data(time, data)
        return time, data

    # TODO: rename this to something better?
    def user_definable_load(self, filename: str) -> tuple[dt.datetime, NDArray]:
        """
        Load a single file. Must be implemented by subclasses and return the
        validity time and the 2D data array for that file.
        """
        raise NotImplementedError

    def _check_loaded_data(
        self,
        output_time: dt.datetime,
        output_arr: NDArray,
    ) -> None:
        """
        Check that loaded data is 2D, matches the shape of previously loaded data
        and that its time is a datetime object.

        Raises:
            TypeError: if output_time is not a datetime object
            Any error raised by check_arrays for an invalid array
        """
        # Check consistency of data shape. With no shape recorded yet, check_arrays
        # skips the shape comparison and only checks the number of dimensions.
        checked_arr = check_arrays(output_arr, shape=self.domain_shape, ndim=2)
        if self.domain_shape is None:
            # Read the shape from the converted array so list/tuple data works too,
            # and only record it once the array has passed the checks
            self.domain_shape = checked_arr.shape

        # Check output time is a sensible type
        if not isinstance(output_time, dt.datetime):
            raise TypeError(
                f"Expected 'output_time' to be datetime object, got {type(output_time)}"
            )
81
+
82
+
83
class DictIterator(BaseLoader):
    """
    An alternative loading solution for users who wish to load and/or pre-process
    their data elsewhere and pass it directly to Simple-Track. The input should be
    a dictionary with datetime keys and 2D array values. This will then iterate
    through the dictionary in datetime order.
    """

    def __init__(self, input_dict: dict) -> None:
        """
        Args:
            input_dict (dict):
                Mapping of datetime -> 2D array

        Raises:
            TypeError: if input_dict is not a dict or any key is not a datetime
        """
        # BaseLoader.__init__ is deliberately not called: it only accepts
        # lists/tuples of filenames
        if not isinstance(input_dict, dict):
            raise TypeError(f"Expected input_data type dict, got {type(input_dict)}")
        # Check key types before sorting: sorting keys of mixed types raises its
        # own, less helpful, TypeError
        if not all(isinstance(key, dt.datetime) for key in input_dict):
            raise TypeError("Expected all input keys to be of type dt.datetime")

        self.domain_shape = None
        self.input_data = input_dict
        # Set the iterating list
        self.iterator = sorted(input_dict)
        # Initialised here so next() also works without a prior iter() call
        self.iter_idx = 0

    def __next__(self) -> tuple[dt.datetime, NDArray]:
        """
        Return the next (time, data) pair in datetime order after checking it.
        """
        if self.iter_idx >= len(self.iterator):
            raise StopIteration
        time = self.iterator[self.iter_idx]
        data = self.input_data[time]
        self.iter_idx += 1
        self._check_loaded_data(time, data)
        return time, data
109
+
110
+
111
class MWELoader(BaseLoader):
    """
    Loader for text files readable by numpy.loadtxt.

    The validity time is derived from the filename: the character seven places
    from the end is read as an integer index and converted to that many
    five-minute steps after 2024-01-01 00:00.
    """

    def __init__(self, filenames: list):
        super().__init__(filenames)

    def user_definable_load(self, filename):
        """
        Load one text file and return (time, data).
        """
        import numpy as np

        field = np.loadtxt(filename)
        name = str(filename)
        # Remember the most recently loaded file
        self.file_id = name
        step = int(name[-7])
        valid_time = dt.datetime(2024, 1, 1, 0, 0, 0) + dt.timedelta(minutes=5 * step)
        return valid_time, field
124
+
125
+
126
class ChilboltonLoader(BaseLoader):
    """
    Loader for netCDF files holding a variable named "var".

    A fixed sub-domain of the variable is read, scaled, transposed and flipped
    vertically. The validity time is parsed from fixed character positions at
    the end of the filename.
    """

    def __init__(self, filenames: list):
        super().__init__(filenames)

    def user_definable_load(self, filename):
        """
        Load one netCDF file and return (time, data).
        """
        import numpy as np
        from netCDF4 import Dataset as ncfile

        # Use a context manager so the file handle is released as soon as the
        # sub-domain has been read, rather than being left open for every file
        with ncfile(filename) as nc:
            data = nc.variables["var"][200:600, 250:550] / 32
        data = np.flipud(np.transpose(data))

        name = str(filename)
        # NOTE(review): this slice is 7 characters long, so only one character is
        # left for the day after the 4-character year and 2-character month -
        # confirm against the expected filename pattern
        date_id = name[-18:-11]
        time_id = name[-9:-5]
        time = dt.datetime(
            year=int(date_id[0:4]),
            month=int(date_id[4:6]),
            day=int(date_id[6:]),
            hour=int(time_id[0:2]),
            minute=int(time_id[2:4]),
        )
        return time, data
147
+
148
+
149
class LoadingBar:
    """
    Class for displaying a loading bar in the terminal. Initialised with the total
    number of items to load and the length of the loading bar. The
    "update_progress" method is then called to update the current progress.
    """

    def __init__(self, total, bar_length=20):
        """
        Args:
            total (int):
                Total number of items to load
            bar_length (int, optional):
                Width of the bar in characters
                Defaults to 20
        """
        self.total = total
        self.bar_length = bar_length
        init_padding = int(self.bar_length) * " "
        print(f"Simple-Track Progress: [{init_padding}] 0/{self.total} (0%)", end="\r")

    def update_progress(self, current):
        """
        Redraw the bar for `current` items loaded out of the total.

        Args:
            current (int):
                Number of items loaded so far
        """
        # With nothing to load the bar is treated as complete, which also avoids
        # dividing by zero
        fraction = current / self.total if self.total else 1.0
        arrow = int(fraction * self.bar_length - 1) * "-" + ">"
        padding = int(self.bar_length - len(arrow)) * " "
        # Stay on the same line until the final update
        ending = "\n" if current == self.total else "\r"
        print(
            f"Simple-Track Progress: [{arrow}{padding}] {current}/{self.total} ({int(fraction * 100)}%) ",
            end=ending,
        )
@@ -0,0 +1,12 @@
1
+ import sys
2
+
3
+ from track import Tracker
4
+
5
if __name__ == "__main__":
    # Every command-line argument after the script name is a config path
    config_paths = sys.argv[1:]
    if not config_paths:
        raise Exception("Running SimpleTrack requires path to at least one config")

    for config_path in config_paths:
        # run() is called without input data, so the input path in the config is used
        tracker = Tracker(config_path)
        tracker.run()
simpletrack/track.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ Run the SimpleTrack algorithm to track objects through a sequence of images
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ from yaml import safe_load
9
+
10
+ from simpletrack.flow_solver import FlowSolver
11
+ from simpletrack.frame import Frame, Timeline
12
+ from simpletrack.frame_output import FrameOutputManager
13
+ from simpletrack.frame_tracker import FrameTracker
14
+ from simpletrack.load import ConfigError, DictIterator, LoadingBar, get_loader
15
+
16
+
17
class Tracker:
    """
    Simple-Track manager controlling inputs, processing, outputs
    """

    def __init__(self, config_input: Union[str, Path, dict]) -> None:
        """
        Initialize SimpleTrack with configuration file

        Args:
            config_input (str|Path|dict):
                If str or Path, provides the path to the configuration file
                If dict, contains pre-loaded config parameters

        Raises:
            TypeError: if config_input is not a str, Path or dict
            ConfigError: if the config is missing required inputs
        """
        if isinstance(config_input, (str, Path)):
            config_path = str(config_input)
            self.config = self._read_config(config_path)
        elif isinstance(config_input, dict):
            config_path = None
            self._check_config(config_input)
            self.config = config_input
        else:
            raise TypeError(
                f"Expected config_input type str or dict, got {type(config_input)}"
            )

        self.start_time = None  # Will be set during run()
        self.timeline = Timeline()

        # Always define file_type so run() can rely on it even when the config
        # has no INPUT section
        input_config = self.config.get("INPUT") or {}
        self.file_type = input_config.get("file_type", None)

        if "FLOW_SOLVER" in self.config:
            self.flow_solver = FlowSolver(**self.config["FLOW_SOLVER"])
        else:
            self.flow_solver = FlowSolver()

        if "TRACKING" in self.config:
            self.frame_tracker = FrameTracker(**self.config["TRACKING"])
            self.skip_tracking = self.config["TRACKING"].get("skip_tracking", False)
        else:
            self.frame_tracker = FrameTracker()
            self.skip_tracking = False

        # Output only if flagged in config; a missing "save_data" entry means
        # no output rather than an error
        self.frame_output = None
        output_config = self.config.get("OUTPUT") or {}
        if output_config.get("save_data", False):
            # NOTE(review): start_time is still None here because it is only set
            # in run() - confirm FrameOutputManager copes with that
            self.frame_output = FrameOutputManager(
                output_config.get("path", "./output"),
                output_config.get("experiment_name", "Simple-Track Experiment"),
                self.start_time,
                config_path,
            )

    def run(self, input_data: Union[list[str], dict, None] = None) -> Timeline:
        """
        Runs SimpleTrack using the designated config options.

        Input data can either be read in from filenames (list of str or Path) or
        provided as input using a dictionary.

        If input_data is None, SimpleTrack finds all valid files in the
        config["INPUT"]["path"] directory using "get_filenames_from_input_path".

        If data is being read in using filenames, there must also be an associated
        Loader class name in config["INPUT"]["loader"] that defines how the data
        should be pre-processed and how the validity time should be determined.
        Filenames should be ordered by time. Loaded data will be checked for
        consistent array shapes. See docs or simpletrack/load.py for more.

        If data is being provided as input using dict, it should be passed
        with the respective datetime object as the key, and the numpy array to run
        tracking on as the value. This will not use a predetermined Loader class to
        load the data, although the same checks on consistent array shapes
        will be applied.

        Returns Timeline object containing Frames of data and tracked Features.

        Raises:
            TypeError: if input_data is not a list of str/Path or a dict
        """
        # Get input files to load if inputs not provided
        if input_data is None:
            input_data = self.get_filenames_from_input_path(file_type=self.file_type)

        # Check type of input data and set up loader accordingly
        if isinstance(input_data, list):
            valid_types = (str, Path)
            if not all(isinstance(fnm, valid_types) for fnm in input_data):
                types = [type(fnm) for fnm in input_data]
                raise TypeError(
                    "If input_data is list it must only contain str or Path, "
                    f"got {types}"
                )
            self.loading_bar = LoadingBar(total=len(input_data))
            self.loader = get_loader(self.config["INPUT"]["loader"])(input_data)

        elif isinstance(input_data, dict):
            self.loading_bar = LoadingBar(total=len(input_data))
            self.loader = DictIterator(input_data)

        else:
            raise TypeError(
                f"Expected input_data type list(str) or dict, got {type(input_data)}"
            )

        # Iterate through sorted input data, perform tracking, output results if flagged
        for fnm_idx, (time, data) in enumerate(self.loader):
            if self.start_time is None:
                self.start_time = time

            # Import data to Frame and add to Timeline
            frame = Frame()
            frame.import_time_and_data(time, data)
            frame.identify_features(**self.config["FEATURE"])
            self.timeline.add_to_timelime(frame)

            # The first frame has nothing to be tracked against, and tracking can
            # be disabled in the config
            if len(self.timeline.timeline) > 1 and not self.skip_tracking:
                self._track_frame(frame)

            self._output_frame(frame)
            self.loading_bar.update_progress(fnm_idx + 1)

        # Output additional fields if flagged
        if self.frame_output is not None:
            self.frame_output.output_density_field(
                self.timeline, "init", centroid_only=False
            )
            self.frame_output.output_density_field(
                self.timeline, "dissipation", centroid_only=False
            )
        return self.timeline

    def _track_frame(self, frame: Frame) -> None:
        """
        Run the flow solver between the previous frame and `frame`, then match
        Features between the two.
        """
        prev_frame = self.timeline.get_previous_frame(frame.time)
        # Set max id for assigning to new features
        frame.max_id = prev_frame.max_id
        # Get the flow field that translates features between the two frames
        y_flow, x_flow = self.flow_solver.analyse_flow(prev_frame, frame)

        # Update the current Frame with these displacements
        if y_flow is not None or x_flow is not None:
            frame.assign_displacements(y_flow, x_flow)

        # Match Features between Frames
        self.frame_tracker.run(prev_frame, frame)

    def _output_frame(self, frame: Frame) -> None:
        """
        Write the features and fields of `frame` to file if output is enabled.
        """
        if self.frame_output is None:
            return
        self.frame_output.features_to_txt(frame)
        self.frame_output.features_to_csv(frame)
        self.frame_output.fields_to_npy(frame)

    # TODO: parallel execution (splitting the inputs into chunks that are run in
    # separate processes) is not implemented. It needs a way to call run() per
    # chunk without falling back to the config input path, and a way to make
    # results consistent between chunks: a feature present at the end of one chunk
    # and the start of the next needs a consistent ID, updated lifetime etc.

    def get_filenames_from_input_path(
        self, input_path: str = None, file_type: str = None
    ) -> list:
        """
        Get a sorted list of filenames from a given input path matching a given
        file type

        Args:
            input_path (str, optional):
                Input path to search for filenames
                Defaults to self.config["INPUT"]["path"]
            file_type (str or list of str, optional):
                Additional file suffix(es) to search input_path for, with or
                without the leading dot. ".nc" files are always included.

        Returns:
            list:
                Sorted Path objects of the matching files

        Raises:
            TypeError: if file_type is not a str or a list of str
            FileNotFoundError: if no matching files are found
        """
        if input_path is None:
            input_path = self.config["INPUT"]["path"]

        if file_type is None:
            extra_filetypes = []
        elif isinstance(file_type, str):
            extra_filetypes = [file_type]
        elif isinstance(file_type, list):
            if not all(isinstance(val, str) for val in file_type):
                types = [type(val) for val in file_type]
                raise TypeError(f"Expected list to contain only str, got {types}")
            extra_filetypes = file_type
        else:
            raise TypeError(f"Expected list or str, got {type(file_type)}")

        supported_filetypes = {".nc"}
        for ftype in extra_filetypes:
            # Path.suffix always starts with a dot, so a type given without one
            # (e.g. "txt") could otherwise never match
            if ftype and not ftype.startswith("."):
                ftype = f".{ftype}"
            supported_filetypes.add(ftype)

        filenames = sorted(
            p
            for p in Path(input_path).iterdir()
            if p.is_file() and p.suffix in supported_filetypes
        )
        if not filenames:
            raise FileNotFoundError(f"No files found in directory: {input_path}")
        return filenames

    def _read_config(self, config_path: str) -> dict:
        """
        Read config, check for necessary arguments, return dict of parameters.

        Args:
            config_path (str):
                Path to config

        Returns:
            dict:
                Simple-Track parameters

        Raises:
            ConfigError: if the config is missing required inputs
        """
        with open(config_path, encoding="utf-8") as config_file:
            config = safe_load(config_file)
        self._check_config(config)
        return config

    def _check_config(self, config: dict) -> None:
        """
        Check that the config contains the required sections and parameters.

        Raises:
            ConfigError: if config is not a dict, the FEATURE section is missing
                or it has no threshold entry
        """
        # An empty YAML file loads as None, so check the type first
        if not isinstance(config, dict):
            raise ConfigError(
                f"Expected config to be a dict of sections, got {type(config)}"
            )

        # Check required top-level sections are present
        required_sections = ["FEATURE"]
        if not all(section in config for section in required_sections):
            raise ConfigError(
                f"config missing one or more required sections: {required_sections}"
            )

        # NOTE: required input parameters are not checked here; a missing
        # config["INPUT"]["path"] or config["INPUT"]["loader"] raises KeyError
        # when run() needs it.
        feature_config = config["FEATURE"]
        if not isinstance(feature_config, dict) or "threshold" not in feature_config:
            raise ConfigError("config missing required threshold input")
simpletrack/utils.py ADDED
@@ -0,0 +1,145 @@
1
+ import numpy as np
2
+
3
+ from simpletrack.exceptions import (
4
+ ArrayShapeError,
5
+ ArrayTypeError,
6
+ FloatIDError,
7
+ IDError,
8
+ NegativeIDError,
9
+ ZeroIDError,
10
+ )
11
+
12
+
13
def check_arrays(
    *args, shape=None, ndim=None, dtype=None, equal_shape=False, non_negative=False
):
    """
    Check that inputs are array-like and satisfy the requested constraints,
    converting lists and tuples to numpy arrays.

    Args:
        *args (array, list or tuple):
            Inputs to check
        shape (tuple, optional):
            Shape every input must have
        ndim (int, optional):
            Number of dimensions every input must have
        dtype (type, optional):
            Required dtype, either int or float. Inputs of another dtype are
            cast if that can be done without changing any value.
        equal_shape (bool, optional):
            If True, all inputs must share the same shape
        non_negative (bool, optional):
            If True, all values must be >= 0

    Returns:
        The checked numpy array if a single input was given, otherwise a list
        of the checked numpy arrays

    Raises:
        ArrayTypeError: if an input is not array-like, has (and cannot be cast
            to) the wrong dtype, contains negative values when non_negative is
            set, or dtype is not int or float
        ArrayShapeError: if a shape, ndim or equal_shape check fails
    """
    # Check inputs args are array like, convert to numpy array if possible,
    # otherwise raise ArrayTypeError
    modified_args = []
    for arr in args:
        if isinstance(arr, np.ndarray):
            modified_args.append(arr)
        elif isinstance(arr, (list, tuple)):
            modified_args.append(np.array(arr))
        else:
            raise ArrayTypeError("args must be an array-like (array, list or tuple)")

    # Check each array has the required shape
    if shape is not None:
        for arr in modified_args:
            if arr.shape != shape:
                msg = f"""
                Argument with shape {arr.shape} does not have required shape {shape}
                """
                raise ArrayShapeError(msg)

    # Check each array has required number of dimensions
    if ndim is not None:
        for arr in modified_args:
            if arr.ndim != ndim:
                msg = (
                    f"Argument with ndim {arr.ndim} does not have required ndim {ndim}"
                )
                raise ArrayShapeError(msg)

    # Check each array has the required dtype
    if dtype is not None:
        # Change python base types to numpy types for looser comparison
        if dtype is int:
            np_dtype = np.integer
        elif dtype is float:
            np_dtype = np.floating
        else:
            raise ArrayTypeError(f"Unsupported dtype {dtype} for check_arrays")

        for idx, arr in enumerate(modified_args):
            if np.issubdtype(arr.dtype, np_dtype):
                continue
            try:
                # Store the cast array so the caller receives the requested dtype
                modified_args[idx] = arr.astype(dtype, casting="same_value")
            except (ValueError, TypeError):
                msg = f"""
                Argument with dtype {arr.dtype} does not have and cannot be cast to
                required dtype {dtype}
                """
                raise ArrayTypeError(msg) from None

    # Check each input array is equal size. The converted arrays are used because
    # list and tuple inputs have no shape attribute.
    if equal_shape and modified_args:
        arr0_shape = modified_args[0].shape
        if not all(arr.shape == arr0_shape for arr in modified_args):
            msg = f"Input array shapes differ: {[arr.shape for arr in modified_args]}"
            raise ArrayShapeError(msg)

    # Check all values are non-negative
    if non_negative:
        if not all(np.all(arr >= 0) for arr in modified_args):
            msg = "Expected inputs to contain non-negative values"
            raise ArrayTypeError(msg)

    # Don't want to return a single arg input as a list
    if len(modified_args) == 1:
        return modified_args[0]
    return modified_args
84
+
85
+
86
def check_valid_ids(*args):
    """
    Checks that all inputs (scalar or vector) contain valid id data.

    A scalar must be a positive, nonzero integer. A vector must contain only
    non-negative integers. Whole-number floats are converted to int.

    NOTE(review): zeros are rejected for scalars but accepted inside vectors -
    confirm this difference is intended.

    Args:
        *args (scalar, list, tuple or array):
            Inputs to check

    Returns:
        The checked input if a single input was given, otherwise a list of the
        checked inputs. Scalars are returned as native ints and vectors as
        integer numpy arrays; an empty vector is returned as an empty list.

    Raises:
        IDError: if an input is a str
        FloatIDError: if an input contains non-integer values
        ZeroIDError: if a scalar input is 0
        NegativeIDError: if an input contains negative values
    """
    modified_args = []
    for arg in args:
        if isinstance(arg, str):
            raise IDError("Cannot interpret str as ID")

        if np.isscalar(arg):
            arg_native = native(arg)
            # Check if turning input into int would not change its value
            # If so, continue checks with int version
            if int(arg_native) == arg_native:
                arg_native = int(arg_native)
            if not np.issubdtype(type(arg_native), np.integer):
                raise FloatIDError(f"{arg_native} not an int")
            if arg_native == 0:
                raise ZeroIDError("Valid IDs start from 1, got 0")
            if arg_native < 0:
                raise NegativeIDError(f"Valid IDs start from 1, got {arg_native}")
            modified_args.append(arg_native)
            continue

        # Looking at vector inputs
        arg_array = np.array(arg) if isinstance(arg, (list, tuple)) else arg
        if len(arg_array) == 0:
            # Nothing to validate, but keep going so any remaining inputs are
            # still checked and returned
            modified_args.append([])
            continue

        # Check if turning input into int would not change its value
        # If so, continue checks with int version
        if np.all(arg_array.astype(int) == arg_array):
            arg_array = arg_array.astype(int)

        if not np.issubdtype(arg_array.dtype, np.integer):
            raise FloatIDError(f"Array must contain ints only: {arg_array}")
        # np.any rather than the builtin so multi-dimensional arrays work
        if np.any(arg_array < 0):
            raise NegativeIDError(
                f"Array must contain positive ints only: {arg_array}"
            )
        modified_args.append(arg_array)

    # Don't want to return a single arg input as a list
    if len(modified_args) == 1:
        return modified_args[0]
    return modified_args
132
+
133
+
134
def native(value):
    """
    Return the native python equivalent of a numpy scalar.

    Any value exposing a "tolist" method is converted by calling it; every
    other value is handed back untouched.

    Args:
        value (any): Input value

    Returns:
        any: Converted value
    """
    try:
        to_native = value.tolist
    except AttributeError:
        return value
    return to_native()