Simple-Track 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simple_track-2.0.0.dist-info/METADATA +218 -0
- simple_track-2.0.0.dist-info/RECORD +17 -0
- simple_track-2.0.0.dist-info/WHEEL +5 -0
- simple_track-2.0.0.dist-info/entry_points.txt +2 -0
- simple_track-2.0.0.dist-info/licenses/LICENSE +373 -0
- simple_track-2.0.0.dist-info/top_level.txt +1 -0
- simpletrack/__init__.py +1 -0
- simpletrack/exceptions.py +51 -0
- simpletrack/feature.py +322 -0
- simpletrack/flow_solver.py +589 -0
- simpletrack/frame.py +521 -0
- simpletrack/frame_output.py +295 -0
- simpletrack/frame_tracker.py +962 -0
- simpletrack/load.py +170 -0
- simpletrack/run_simple_track.py +12 -0
- simpletrack/track.py +281 -0
- simpletrack/utils.py +145 -0
simpletrack/load.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import datetime as dt
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
from simpletrack.utils import check_arrays
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ConfigError(Exception):
    """Raised when one or more configuration input parameters are invalid."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_loader(loader_key: Union[str, type]):
    """
    Return the loader class to use for reading input files.

    Args:
        loader_key (str | type):
            Name of a built-in loader ("MWELoader" or "ChilboltonLoader"), or a
            custom subclass of BaseLoader passed directly.

    Returns:
        type: The loader class (not an instance).

    Raises:
        KeyError: If loader_key is a name that does not match a built-in loader.
        TypeError: If the resolved loader is not a subclass of BaseLoader.
    """
    if isinstance(loader_key, type):
        # Allow user-defined BaseLoader subclasses to be supplied directly
        loader = loader_key
    else:
        available_loaders = {
            "MWELoader": MWELoader,
            "ChilboltonLoader": ChilboltonLoader,
        }
        try:
            loader = available_loaders[loader_key]
        except KeyError as err:
            raise KeyError(
                f"Unknown loader: {loader_key}. "
                f"Available loaders: {sorted(available_loaders)}"
            ) from err
    if not issubclass(loader, BaseLoader):
        raise TypeError(f"Requested loader ({loader}) is not type BaseLoader")
    return loader
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class BaseLoader:
    """
    Base class for building custom loaders for use with Simple-Track.

    To use, inherit from this class and implement the `user_definable_load` method,
    which takes a single input (one element of `input_data`, e.g. a filename) and
    should return a `(datetime, array)` pair.

    The loader should be initialised with a list (or tuple) of inputs, which will be
    iterated through in order when the loader is used in Simple-Track.

    Loaded data is checked for consistency and type before being passed to
    Simple-Track. Array-likes are converted to numpy arrays. Every field must be 2D
    and have the same shape as the first field loaded. Every time must be a
    `datetime.datetime`. The user therefore only needs to worry about loading the
    data in the correct format.
    """

    def __init__(self, input_data: Union[list, tuple]) -> None:
        """
        Args:
            input_data (list | tuple):
                Inputs passed one at a time to `user_definable_load`, e.g. filenames.

        Raises:
            TypeError: If input_data is not a list or tuple.
        """
        # Shape of the first loaded field; every later field must match it
        self.domain_shape = None
        self.input_data = input_data
        # Initialised here so next() also works before iter() has been called
        self.iter_idx = 0
        if not isinstance(input_data, (list, tuple)):
            raise TypeError(f"Expected input_data type list, got {type(input_data)}")

    def __iter__(self):
        """Restart iteration from the first input and return self."""
        self.iter_idx = 0
        return self

    def __next__(self) -> tuple[dt.datetime, NDArray]:
        """
        Load, check and return the next `(time, data)` pair.

        Raises:
            StopIteration: When every input has been loaded.
        """
        if self.iter_idx >= len(self.input_data):
            raise StopIteration
        next_input = self.input_data[self.iter_idx]
        self.iter_idx += 1
        time, data = self.user_definable_load(next_input)
        # Use the validated array so array-likes (e.g. nested lists) are converted
        data = self._check_loaded_data(time, data)
        return time, data

    # TODO: rename this to something better?
    def user_definable_load(self, filename: str) -> tuple[dt.datetime, NDArray]:
        """
        Load a single input. Must be implemented by subclasses.

        Args:
            filename: One element of `input_data`.

        Returns:
            tuple: `(time, data)`, where time is a `datetime.datetime` and data is a
            2D array-like.
        """
        raise NotImplementedError

    def _check_loaded_data(
        self,
        output_time: dt.datetime,
        output_arr: NDArray,
    ) -> NDArray:
        """
        Validate one loaded `(time, data)` pair and return the data as an ndarray.

        Raises:
            ArrayTypeError / ArrayShapeError (from check_arrays): If the data is not
                array-like, not 2D, or differs in shape from the first field loaded.
            TypeError: If output_time is not a datetime.
        """
        # Convert and check dimensionality BEFORE recording the domain shape, so an
        # invalid first field cannot become the reference shape
        output_arr = check_arrays(output_arr, ndim=2)
        if self.domain_shape is None:
            self.domain_shape = output_arr.shape
        output_arr = check_arrays(output_arr, shape=self.domain_shape)

        # Check output time is a sensible type
        if not isinstance(output_time, dt.datetime):
            raise TypeError(
                f"Expected 'output_time' to be datetime object, got {type(output_time)}"
            )
        return output_arr
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class DictIterator(BaseLoader):
    """
    An alternative loading solution for users who wish to load and/or pre-process
    their data elsewhere and pass it directly to Simple-Track.

    The input should be a dictionary with datetime keys and 2D array values. The
    iterator yields `(time, data)` pairs in ascending datetime order.
    """

    def __init__(self, input_dict: dict) -> None:
        """
        Args:
            input_dict (dict): Mapping of `datetime.datetime` to 2D array-like.

        Raises:
            TypeError: If input_dict is not a dict or any key is not a datetime.
        """
        self.domain_shape = None
        self.input_data = input_dict
        if not isinstance(input_dict, dict):
            raise TypeError(f"Expected input_data type dict, got {type(input_dict)}")
        # Check key types before sorting, so mixed key types produce this clear error
        # rather than a comparison TypeError raised from inside sorted()
        if not all(isinstance(key, dt.datetime) for key in input_dict):
            raise TypeError("Expected all input keys to be of type dt.datetime")
        # Keys in the order they are iterated
        self.iterator = sorted(input_dict)
        # Initialised here so next() also works before iter() has been called
        self.iter_idx = 0

    def __next__(self) -> tuple[dt.datetime, NDArray]:
        """
        Return the next `(time, data)` pair in time order.

        Raises:
            StopIteration: When every key has been visited.
        """
        if self.iter_idx >= len(self.iterator):
            raise StopIteration
        time = self.iterator[self.iter_idx]
        self.iter_idx += 1
        # Normalise array-likes (e.g. nested lists) to ndarray so the checks below
        # and the caller see the same array object
        data = check_arrays(self.input_data[time])
        self._check_loaded_data(time, data)
        return time, data
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class MWELoader(BaseLoader):
    """
    Loader for text-array input files read with `numpy.loadtxt` (presumably the
    minimal-working-example data — confirm).

    The validity time is 2024-01-01 00:00 plus 5 minutes times the digit found
    7 characters from the end of the filename.
    """

    def __init__(self, filenames: list):
        super().__init__(filenames)

    def user_definable_load(self, filename):
        """
        Load one file and derive its validity time from the filename.

        Returns:
            tuple: `(datetime, array)` for the file.
        """
        import numpy as np

        field = np.loadtxt(filename)
        self.file_id = str(filename)
        # A single character at a fixed offset from the end of the name is the
        # frame index, so only a one-digit index can be represented
        frame_index = int(str(filename)[-7])
        origin = dt.datetime(2024, 1, 1, 0, 0, 0)
        valid_time = origin + dt.timedelta(minutes=5 * frame_index)
        return valid_time, field
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class ChilboltonLoader(BaseLoader):
    """
    Loader for Chilbolton netCDF input files.

    Each file's "var" variable is cropped to rows 200:600 and columns 250:550,
    divided by 32, transposed and flipped vertically.

    The validity time is parsed from fixed positions at the end of the filename,
    assumed to follow the layout `...YYYYMMDD_HHMMSS.nc`. The seconds are ignored.
    NOTE(review): filename layout inferred from the slice offsets — confirm.
    """

    def __init__(self, filenames: list):
        super().__init__(filenames)

    def user_definable_load(self, filename):
        """
        Load one netCDF file and derive its validity time from the filename.

        Returns:
            tuple: `(datetime, array)` for the file.
        """
        import numpy as np
        from netCDF4 import Dataset as ncfile

        # Context manager ensures the file handle is closed once the cropped field
        # has been read into memory
        with ncfile(filename) as nc:
            data = nc.variables["var"][200:600, 250:550] / 32
        data = np.flipud(np.transpose(data))

        name = str(filename)
        # YYYYMMDD is 8 characters: [-18:-10] (the previous [-18:-11] dropped the
        # last day digit, so `day` was parsed from a single character)
        date_id = name[-18:-10]
        time_id = name[-9:-5]  # HHMM
        time = dt.datetime(
            year=int(date_id[0:4]),
            month=int(date_id[4:6]),
            day=int(date_id[6:8]),
            hour=int(time_id[0:2]),
            minute=int(time_id[2:4]),
        )
        return time, data
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class LoadingBar:
    """
    Class for displaying a loading bar in the terminal.

    Initialised with the total number of items to load and the length of the bar in
    characters; the empty bar is printed immediately. The `update_progress` method is
    then called with the number of items completed to redraw the bar in place.
    """

    def __init__(self, total, bar_length=20):
        self.total = total
        self.bar_length = bar_length
        init_padding = int(self.bar_length) * " "
        print(f"Simple-Track Progress: [{init_padding}] 0/{self.total} (0%)", end="\r")

    def update_progress(self, current):
        """
        Redraw the bar for `current` completed items.

        The line is ended with a newline once `current == total`, otherwise with a
        carriage return so the next update overwrites it.
        """
        # An empty workload is already complete; also avoids division by zero
        fraction = current / self.total if self.total else 1.0
        arrow = int(fraction * self.bar_length - 1) * "-" + ">"
        padding = int(self.bar_length - len(arrow)) * " "
        ending = "\n" if current == self.total else "\r"
        print(
            f"Simple-Track Progress: [{arrow}{padding}] {current}/{self.total} ({int(fraction * 100)}%) ",
            end=ending,
        )
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import sys

# Package-qualified import works both when installed and when run with
# `python -m simpletrack.run_simple_track`
from simpletrack.track import Tracker


def main(argv=None) -> None:
    """
    Run Simple-Track once for each config path given on the command line.

    Args:
        argv (list[str], optional):
            Config file paths. Defaults to `sys.argv[1:]`.

    Raises:
        SystemExit: If no config path is given.
    """
    config_paths = sys.argv[1:] if argv is None else list(argv)
    if not config_paths:
        raise SystemExit("Running SimpleTrack requires path to at least one config")

    for config_path in config_paths:
        # With no input_data passed to run(), the input path in the config is used
        Tracker(config_path).run()


if __name__ == "__main__":
    main()
|
simpletrack/track.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Run the SimpleTrack algorithm to track objects through a sequence of images
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Union
|
|
7
|
+
|
|
8
|
+
from yaml import safe_load
|
|
9
|
+
|
|
10
|
+
from simpletrack.flow_solver import FlowSolver
|
|
11
|
+
from simpletrack.frame import Frame, Timeline
|
|
12
|
+
from simpletrack.frame_output import FrameOutputManager
|
|
13
|
+
from simpletrack.frame_tracker import FrameTracker
|
|
14
|
+
from simpletrack.load import ConfigError, DictIterator, LoadingBar, get_loader
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Tracker:
    """
    Simple-Track manager controlling inputs, processing and outputs.

    A Tracker is configured once, from a YAML config path or a pre-loaded dict.
    `run` then does the following for each input field:
      - loads it and identifies features;
      - tracks those features against the previous frame;
      - optionally writes outputs.
    """

    def __init__(self, config_input: Union[str, Path, dict]) -> None:
        """
        Initialise Simple-Track from a configuration.

        Args:
            config_input (str | Path | dict):
                If str or Path, the path to a YAML configuration file.
                If dict, pre-loaded config parameters.

        Raises:
            TypeError: If config_input is not a str, Path or dict.
            ConfigError: If required config sections/parameters are missing.
        """
        if isinstance(config_input, (str, Path)):
            config_path = str(config_input)
            self.config = self._read_config(config_path)
        elif isinstance(config_input, dict):
            config_path = None
            self._check_config(config_input)
            self.config = config_input
        else:
            raise TypeError(
                f"Expected config_input type str or dict, got {type(config_input)}"
            )

        self.start_time = None  # Set from the first loaded frame during run()
        self.timeline = Timeline()

        # Optional sections may be absent, or present but empty (None) in YAML;
        # treat both cases as an empty section
        input_config = self.config.get("INPUT") or {}
        flow_config = self.config.get("FLOW_SOLVER") or {}
        tracking_config = self.config.get("TRACKING") or {}
        output_config = self.config.get("OUTPUT") or {}

        # Extra file suffix(es) searched for when run() is called without input_data
        self.file_type = input_config.get("file_type", None)

        self.flow_solver = FlowSolver(**flow_config)

        # NOTE(review): the whole TRACKING section, including "skip_tracking", is
        # forwarded to FrameTracker (as before) — confirm FrameTracker accepts it
        self.frame_tracker = FrameTracker(**tracking_config)
        self.skip_tracking = tracking_config.get("skip_tracking", False)

        # Output only if flagged in config
        self.frame_output = None
        if output_config.get("save_data", False):
            self.frame_output = FrameOutputManager(
                output_config.get("path", "./output"),
                output_config.get("experiment_name", "Simple-Track Experiment"),
                # NOTE(review): start_time is still None here; it is only set from
                # the first frame inside run(), after this manager is created
                self.start_time,
                config_path,
            )

    def run(self, input_data: Union[list, tuple, dict, None] = None) -> "Timeline":
        """
        Run Simple-Track using the configured options.

        Input data can either be read in from filenames (list/tuple of str or Path)
        or provided directly as a dictionary.

        If input_data is None, the files in config["INPUT"]["path"] with a supported
        suffix are found using `get_filenames_from_input_path`.

        If data is read in from filenames, config["INPUT"]["loader"] must select a
        Loader (see `simpletrack.load.get_loader`). The Loader defines how each file
        is loaded and how its validity time is determined. Filenames should be
        ordered by time. Loaded data is checked for consistent 2D array shapes; see
        simpletrack/load.py for more.

        If data is provided as a dict, its keys should be datetime objects and its
        values the arrays to run tracking on. No Loader class is used, but the same
        array-shape checks are applied, and frames are processed in time order.

        Args:
            input_data (list | tuple | dict, optional): Filenames or time->array dict.

        Returns:
            Timeline: Timeline containing the Frames of data and tracked Features.

        Raises:
            TypeError: If input_data, or any filename in it, has an unsupported type.
        """
        # Get input files to load if inputs not provided
        if input_data is None:
            input_data = self.get_filenames_from_input_path(file_type=self.file_type)

        self._setup_loader(input_data)

        # Iterate through input data, perform tracking, output results if flagged
        for n_loaded, (valid_time, field) in enumerate(self.loader, start=1):
            if self.start_time is None:
                self.start_time = valid_time

            # Import data to Frame and add to Timeline
            frame = Frame()
            frame.import_time_and_data(valid_time, field)
            frame.identify_features(**self.config["FEATURE"])
            self.timeline.add_to_timelime(frame)

            # The first frame has nothing to track against; tracking can also be
            # disabled via config
            is_first_frame = len(self.timeline.timeline) == 1
            if not (is_first_frame or self.skip_tracking):
                self._track_against_previous(frame)

            self._output_frame(frame)
            self.loading_bar.update_progress(n_loaded)

        # Output additional fields if flagged
        if self.frame_output is not None:
            self.frame_output.output_density_field(
                self.timeline, "init", centroid_only=False
            )
            self.frame_output.output_density_field(
                self.timeline, "dissipation", centroid_only=False
            )
        return self.timeline

    # TODO: parallel running (splitting the inputs into one chunk per process) was
    # prototyped here. It still needs a way to make results consistent between
    # chunks: e.g. a storm at the end of one chunk and the start of the next must
    # keep a consistent ID and lifetime.

    def _setup_loader(self, input_data: Union[list, tuple, dict]) -> None:
        """
        Create `self.loader` and `self.loading_bar` for the given input data.

        Raises:
            TypeError: If input_data is not a list/tuple of str/Path or a dict.
        """
        if isinstance(input_data, (list, tuple)):
            valid_types = (str, Path)
            if not all(isinstance(fnm, valid_types) for fnm in input_data):
                types = [type(fnm) for fnm in input_data]
                raise TypeError(
                    f"If input_data is list it must only contain str or Path, got {types}"
                )
            # Resolve the loader first so a bad loader key fails before the
            # progress bar is printed
            loader_class = get_loader(self.config["INPUT"]["loader"])
            self.loading_bar = LoadingBar(total=len(input_data))
            self.loader = loader_class(input_data)
        elif isinstance(input_data, dict):
            self.loading_bar = LoadingBar(total=len(input_data))
            self.loader = DictIterator(input_data)
        else:
            raise TypeError(
                f"Expected input_data type list(str) or dict, got {type(input_data)}"
            )

    def _track_against_previous(self, frame: "Frame") -> None:
        """Run the flow solver and feature matching between `frame` and the frame before it."""
        prev_frame = self.timeline.get_previous_frame(frame.time)
        # Set max id for assigning to new features
        frame.max_id = prev_frame.max_id
        # Get the flow field that translates features between the two frames
        y_flow, x_flow = self.flow_solver.analyse_flow(prev_frame, frame)
        # Update the current Frame with these displacements
        if y_flow is not None or x_flow is not None:
            frame.assign_displacements(y_flow, x_flow)
        # Match Features between Frames
        self.frame_tracker.run(prev_frame, frame)

    def _output_frame(self, frame: "Frame") -> None:
        """Write the frame's features (txt, csv) and fields (npy) if output is enabled."""
        if self.frame_output is None:
            return
        self.frame_output.features_to_txt(frame)
        self.frame_output.features_to_csv(frame)
        self.frame_output.fields_to_npy(frame)

    def get_filenames_from_input_path(
        self,
        input_path: Union[str, Path, None] = None,
        file_type: Union[str, list, tuple, None] = None,
    ) -> list:
        """
        Get a sorted list of files in a directory that match the supported suffixes.

        Args:
            input_path (str | Path, optional):
                Directory to search (not recursive).
                Defaults to self.config["INPUT"]["path"].
            file_type (str | list[str], optional):
                Additional file suffix(es) to accept, with or without the leading
                dot (e.g. ".txt" or "txt"). ".nc" files are always accepted.

        Returns:
            list[Path]: Matching files, sorted by path.

        Raises:
            TypeError: If file_type is not a str or a list/tuple of str.
            FileNotFoundError: If no matching files are found.
        """
        if input_path is None:
            input_path = self.config["INPUT"]["path"]

        supported_filetypes = [".nc"]
        if file_type is not None:
            if isinstance(file_type, str):
                extra_types = [file_type]
            elif isinstance(file_type, (list, tuple)):
                if not all(isinstance(val, str) for val in file_type):
                    types = [type(val) for val in file_type]
                    raise TypeError(f"Expected list to contain only str, got {types}")
                extra_types = list(file_type)
            else:
                raise TypeError(f"Expected list or str, got {type(file_type)}")
            # Path.suffix includes the leading dot, so accept "txt" as well as ".txt"
            supported_filetypes.extend(
                f".{ftype}" if ftype and not ftype.startswith(".") else ftype
                for ftype in extra_types
            )

        filenames = sorted(
            p
            for p in Path(input_path).iterdir()
            if p.is_file() and p.suffix in supported_filetypes
        )
        if not filenames:
            raise FileNotFoundError(f"No files found in directory: {input_path}")
        return filenames

    def _read_config(self, config_path: str) -> dict:
        """
        Read a YAML config, check it has the required arguments and return it.

        Args:
            config_path (str):
                Path to config

        Returns:
            dict:
                Simple-Track parameters

        Raises:
            ConfigError: If the config is missing required sections/parameters.
        """
        with open(config_path, encoding="utf-8") as config_file:
            config = safe_load(config_file)
        self._check_config(config)
        return config

    def _check_config(self, config: dict) -> None:
        """
        Check that the config contains the required sections and parameters.

        Raises:
            ConfigError: If config is not a mapping, lacks a FEATURE section, or the
                FEATURE section has no "threshold".
        """
        # An empty YAML file loads as None; report that as a config error
        if not isinstance(config, dict):
            raise ConfigError(
                f"Expected config to be a mapping of sections, got {type(config)}"
            )

        # Check required top-level sections are present
        required_sections = ["FEATURE"]
        if not all(section in config for section in required_sections):
            raise ConfigError(
                f"config missing one or more required sections: {required_sections}"
            )

        feature_config = config["FEATURE"]
        if not isinstance(feature_config, dict) or "threshold" not in feature_config:
            raise ConfigError("config missing required threshold input")
|
simpletrack/utils.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from simpletrack.exceptions import (
|
|
4
|
+
ArrayShapeError,
|
|
5
|
+
ArrayTypeError,
|
|
6
|
+
FloatIDError,
|
|
7
|
+
IDError,
|
|
8
|
+
NegativeIDError,
|
|
9
|
+
ZeroIDError,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def check_arrays(
    *args, shape=None, ndim=None, dtype=None, equal_shape=False, non_negative=False
):
    """
    Validate, and where necessary convert, one or more array-like arguments.

    Args:
        *args: Arrays, lists or tuples. Lists/tuples are converted with `np.array`.
        shape (tuple | list | int, optional): Required shape of every array.
        ndim (int, optional): Required number of dimensions of every array.
        dtype (int | float, optional): Required dtype kind. An array of another
            dtype is cast to `dtype` when the cast preserves every value;
            otherwise an error is raised.
        equal_shape (bool): Require all arrays to have the same shape.
        non_negative (bool): Require every value to be >= 0.

    Returns:
        The single (converted) array if one argument was given, else a list of arrays.

    Raises:
        ArrayTypeError: Non array-like input, unsupported dtype, uncastable dtype or
            negative values when non_negative is set.
        ArrayShapeError: A shape, ndim or equal_shape check fails.
    """
    # Check inputs args are array like, convert to numpy array if possible,
    # otherwise raise ArrayTypeError
    modified_args = []
    for arr in args:
        if isinstance(arr, np.ndarray):
            modified_args.append(arr)
        elif isinstance(arr, (list, tuple)):
            modified_args.append(np.array(arr))
        else:
            raise ArrayTypeError("args must be an array-like (array, list or tuple)")

    # Check each array has the required shape (ndarray.shape is always a tuple, so
    # normalise list/int shapes before comparing)
    if shape is not None:
        if isinstance(shape, (int, np.integer)):
            required_shape = (int(shape),)
        else:
            required_shape = tuple(shape)
        for arr in modified_args:
            if arr.shape != required_shape:
                msg = f"Argument with shape {arr.shape} does not have required shape {shape}"
                raise ArrayShapeError(msg)

    # Check each array has required number of dimensions
    if ndim is not None:
        for arr in modified_args:
            if arr.ndim != ndim:
                msg = f"Argument with ndim {arr.ndim} does not have required ndim {ndim}"
                raise ArrayShapeError(msg)

    # Check each array has the required dtype
    if dtype is not None:
        # Change python base types to numpy types for looser comparison
        if dtype is int:
            np_dtype = np.integer
        elif dtype is float:
            np_dtype = np.floating
        else:
            raise ArrayTypeError(f"Unsupported dtype {dtype} for check_arrays")

        for idx, arr in enumerate(modified_args):
            if np.issubdtype(arr.dtype, np_dtype):
                continue
            # Cast, then verify the round trip kept every value. This works on all
            # NumPy versions (unlike casting="same_value"), and NaN/inf never compare
            # equal after a float->int cast so they are rejected
            try:
                with np.errstate(invalid="ignore"):
                    cast_arr = arr.astype(dtype)
                lossless = bool(np.array_equal(cast_arr, arr))
            except (ValueError, TypeError):
                lossless = False
            if not lossless:
                msg = (
                    f"Argument with dtype {arr.dtype} does not have and cannot be cast "
                    f"to required dtype {dtype}"
                )
                raise ArrayTypeError(msg) from None
            # Store the cast array; previously the result was silently discarded
            modified_args[idx] = cast_arr

    # Check each input array is equal size (using the converted arrays, so list
    # inputs work too)
    if equal_shape and modified_args:
        first_shape = modified_args[0].shape
        if any(arr.shape != first_shape for arr in modified_args):
            msg = f"Input array shapes differ: {[arr.shape for arr in modified_args]}"
            raise ArrayShapeError(msg)

    # Check all values are non-negative
    if non_negative:
        if not all(np.all(arr >= 0) for arr in modified_args):
            msg = "Expected inputs to contain non-negative values"
            raise ArrayTypeError(msg)

    # Don't want to return a single arg input as a list
    if len(modified_args) == 1:
        return modified_args[0]
    return modified_args
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def check_valid_ids(*args):
    """
    Check that all inputs (scalar or array-like) contain valid ID data.

    Scalars must be integer-valued and >= 1. Integral floats such as 3.0 are
    converted to int.

    Array-likes must be integer-valued and contain no negative values. They are
    converted to an int array when every value is integral. An empty array-like
    yields `[]`.

    NOTE(review): unlike scalars, array inputs are not rejected for containing 0
    (this matches the original behaviour) — confirm whether that is intended.

    Returns:
        The single checked value if one argument was given, else a list of them.

    Raises:
        IDError: If an input is a str.
        FloatIDError: If a value is not integer-valued (including NaN/inf).
        ZeroIDError: If a scalar ID is 0.
        NegativeIDError: If an ID is negative.
    """
    modified_args = []
    for arg in args:
        if isinstance(arg, str):
            raise IDError("Cannot interpret str as ID")

        if np.isscalar(arg):
            arg_native = native(arg)
            # Check if turning input into int would not change its value; if so,
            # continue checks with the int version. NaN/inf cannot be converted.
            try:
                as_int = int(arg_native)
            except (TypeError, ValueError, OverflowError):
                raise FloatIDError(f"{arg_native} not an int") from None
            if as_int == arg_native:
                arg_native = as_int
            if not np.issubdtype(type(arg_native), np.integer):
                raise FloatIDError(f"{arg_native} not an int")
            if arg_native == 0:
                raise ZeroIDError("Valid IDs start from 1, got 0")
            if arg_native < 0:
                raise NegativeIDError(f"Valid IDs start from 1, got {arg_native}")
            modified_args.append(arg_native)
            continue

        # Looking at vector inputs (any array-like, not just list/tuple)
        arg_array = np.asarray(arg)
        if arg_array.size == 0:
            # Keep processing the remaining args instead of returning early
            modified_args.append([])
            continue

        # Check if turning input into int would not change its value; if so,
        # continue checks with the int version
        try:
            with np.errstate(invalid="ignore"):
                as_int_array = arg_array.astype(int)
            if np.array_equal(as_int_array, arg_array):
                arg_array = as_int_array
        except (TypeError, ValueError, OverflowError):
            # Unconvertible contents: left as-is so the dtype check below reports it
            pass

        if not np.issubdtype(arg_array.dtype, np.integer):
            raise FloatIDError(f"Array must contain ints only: {arg_array}")
        # np.any works for arrays of any dimension (builtin any() fails on 2D)
        if np.any(arg_array < 0):
            raise NegativeIDError(f"Array must contain positive ints only: {arg_array}")
        modified_args.append(arg_array)

    # Don't want to return a single arg input as a list
    if len(modified_args) == 1:
        return modified_args[0]
    return modified_args
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def native(value):
    """
    Return `value` as a plain Python object.

    Objects exposing a `tolist` method (such as numpy scalars and arrays) are
    converted by calling it; anything else is handed back untouched.

    Args:
        value (any): Value to convert.

    Returns:
        any: The converted value, or `value` itself if it has no `tolist`.
    """
    if hasattr(value, "tolist"):
        return value.tolist()
    return value
|