Simple-Track 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
simpletrack/frame.py ADDED
@@ -0,0 +1,521 @@
1
+ import datetime as dt
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import scipy.ndimage as ndimage
6
+ from numpy.typing import NDArray
7
+
8
+ from simpletrack.exceptions import FeaturesNotFoundError
9
+ from simpletrack.feature import Feature
10
+ from simpletrack.utils import check_arrays, check_valid_ids
11
+
12
+
13
class Frame:
    """
    Object for storing data and methods related to a single timestep of data. This
    includes the raw data, the feature field, the feature lifetime field, and a dict
    of Feature objects for each feature identified in the frame. The Frame class also
    includes methods for identifying features in the frame, and for assigning
    pre-calculated motion vectors to each Feature.

    Two Frames compare equal when their times are equal. Because ``__eq__`` is
    overridden without ``__hash__``, Frame instances are not hashable.
    """

    def __init__(self):
        # Validity time of the frame (dt.datetime once set)
        self._time = None
        # Raw 2D input data that features are identified from
        self.raw_field = None
        # 2D integer field of feature ids, 0 is reserved for background
        self._feature_field = None
        # 2D integer field holding the lifetime of the feature at each grid point
        self._lifetime_field = None
        # Highest feature id handed out so far in this frame
        self._max_id = None
        # Mapping of feature id -> Feature
        self._features = {}
        # Flow fields, populated by assign_displacements
        self.y_flow = None
        self.x_flow = None

    def __repr__(self) -> str:
        return f"Frame time: {self._time}, Number of Features: {len(self._features)}"

    def __eq__(self, other) -> bool:
        """
        Frames are considered equal if, and only if, they share the same time.
        Comparison against any non-Frame object is False.
        """
        if not isinstance(other, Frame):
            return False
        return self.time == other.time

    @property
    def features(self) -> dict:
        """
        Get all features identified in the frame as a dict, with the feature ids as
        the dict keys and the corresponding Feature objects as the dict values
        """
        return self._features

    @features.setter
    def features(self, features_dict: dict) -> None:
        """
        Set the features dict for the frame, with the feature ids as the
        dict keys and the corresponding Feature objects as the dict values

        Raises:
            TypeError: if features_dict is not a dict
        """
        if not isinstance(features_dict, dict):
            raise TypeError(f"Expected type dict, got {type(features_dict)}")
        self._features = features_dict

    @property
    def time(self) -> Union[dt.datetime, None]:
        """
        Get the datetime object that the current frame is valid for, or None if the
        time has not been set yet
        """
        return self._time

    @time.setter
    def time(self, time: dt.datetime) -> None:
        """
        Set time for the current frame, as a datetime.datetime object

        Raises:
            TypeError: if time is not a datetime.datetime instance
        """
        if not isinstance(time, dt.datetime):
            raise TypeError(
                f"Expected 'time' to be a datetime object, got {type(time)}"
            )
        self._time = time

    @property
    def feature_field(self) -> Union[NDArray[np.integer], None]:
        """
        Get the feature id field for the current frame (None if not yet set)
        """
        return self._feature_field

    @feature_field.setter
    def feature_field(self, feature_field: NDArray) -> None:
        """
        Sets the self._feature_field attribute of the frame, after passing the input
        through check_arrays (requested as 2D, integer dtype).

        Note that this does not touch the lifetime field or the features dict.
        """
        self._feature_field = check_arrays(feature_field, ndim=2, dtype=int)

    @property
    def lifetime_field(self) -> Union[NDArray[np.integer], None]:
        """
        Get the feature lifetime field for the current frame (None if not yet set)
        """
        return self._lifetime_field

    @property
    def max_id(self) -> Union[int, None]:
        """
        Returns max_id of features in the frame (None if not yet set).
        """
        return self._max_id

    @max_id.setter
    def max_id(self, max_id: int) -> None:
        """
        Sets the max_id used for assigning to features that do not match to another
        feature from a previous timestep. The value is passed through
        check_valid_ids before being stored.
        """
        self._max_id = check_valid_ids(max_id)

    def get_feature(self, feature_id: int) -> Union[Feature, None]:
        """
        Get a feature matching the given id if present in the current frame,
        otherwise returns None.
        """
        feature_id = check_valid_ids(feature_id)
        return self._features.get(feature_id)

    def get_flow(self) -> tuple:
        """
        Get the y-flow and x-flow fields assigned to this frame.

        Returns:
            tuple: (y_flow, x_flow). Both entries are None if
            assign_displacements has not been called on this frame.
        """
        return self.y_flow, self.x_flow

    def import_time_and_data(self, time: dt.datetime, data: NDArray) -> None:
        """
        Load time and raw data into the frame.

        Both inputs are validated before anything is stored, so a failed call
        leaves the frame unchanged.

        Args:
            time (dt.datetime): Time the frame is valid for.
            data (NDArray): Raw data to perform tracking on

        Raises:
            TypeError: if time is not a datetime.datetime instance
        """
        # Data is validated first (as before), but only stored once the time has
        # also been accepted by the time setter.
        raw_field = check_arrays(data, ndim=2)
        self.time = time
        self.raw_field = raw_field

    def identify_features(
        self,
        threshold: float,
        under_threshold: bool = False,
        min_size: int = 5,
    ) -> None:
        """
        Call the "label_features" function to identify distinct regions in the raw
        field that meet a specified threshold condition, then build a Feature for
        each identified region.

        Args:
            - threshold (float): Threshold value for identifying regions
            - under_threshold (bool): If True, regions under the threshold
              are considered; if False, regions over the threshold are considered.
            - min_size (int): Minimum area (in number of grid points) for a region
              to be considered valid

        Raises:
            RuntimeError: if no raw data has been loaded into the frame
        """
        if self.raw_field is None:
            raise RuntimeError("Data has not been loaded into Frame")

        self._feature_field = label_features(
            field=self.raw_field,
            min_area=min_size,
            threshold=threshold,
            under_threshold=under_threshold,
        )
        # Provisionally set the lifetime field to 1 anywhere there is a feature
        self._lifetime_field = np.zeros_like(self._feature_field)
        self._lifetime_field[self._feature_field > 0] = 1
        # Goes through the max_id setter so the id is validated
        self.max_id = int(np.max(self._feature_field))
        self.populate_features()

    def populate_features(self) -> None:
        """
        Uses the self._feature_field array to populate the self.features dict with new
        Feature instances. Any previously stored features are discarded. Does
        nothing further if no feature field has been set.
        """
        self._features = {}

        if self._feature_field is None:
            return

        # 0 is reserved for background. Filter it out by value: the field may not
        # contain any background at all, so it cannot be removed by position.
        feature_ids = np.unique(self._feature_field)
        feature_ids = feature_ids[feature_ids != 0]
        feature_ids = check_valid_ids(feature_ids)

        for feature_id in feature_ids:
            # Get the pixel locations of the feature in the field
            # For 2D data, np.where returns two arrays containing y, x locations
            feature_mask = np.where(self._feature_field == feature_id)
            feature_coords = np.array(feature_mask)

            feature = Feature(
                id=feature_id, feature_coords=feature_coords, time=self._time
            )
            # If raw field is not None, use this to find max value within Feature
            if self.raw_field is not None:
                feature.extreme = np.max(self.raw_field[feature_mask])
            self._features[feature_id] = feature

    def assign_displacements(self, y_flow: NDArray, x_flow: NDArray) -> None:
        """
        Add flow field to frame. Use input y_flow and x_flow fields to assign
        dy and dx displacements to each Feature in the Frame, taken as the mean of
        each flow field over the grid points covered by the Feature.

        Args:
            y_flow (NDArray): 2D flow field in the y direction
            x_flow (NDArray): 2D flow field in the x direction, same shape as y_flow

        Raises:
            FeaturesNotFoundError: if the frame has no feature field or no features
        """
        if self._feature_field is None or not self._features:
            raise FeaturesNotFoundError(
                "Features have not been loaded into this Frame. "
                "Cannot assign displacements"
            )

        self.y_flow, self.x_flow = check_arrays(
            y_flow, x_flow, ndim=2, equal_shape=True
        )

        # Use the validated arrays (not the raw arguments) so the displacements
        # are derived from exactly what is stored on the frame.
        for feature_id, feature in self._features.items():
            feature_mask = self._feature_field == feature_id
            feature_dy = np.mean(self.y_flow[feature_mask])
            feature_dx = np.mean(self.x_flow[feature_mask])
            feature.dydx = (feature_dy, feature_dx)

    def get_next_available_feature_id(self) -> int:
        """
        Get the next available feature ID for this Frame.
        Used when new features are created.

        If max_id has not been set, it is first initialised from the maximum of
        the feature field, or 0 if there is no feature field. The stored max_id
        is incremented by every call.

        Returns:
            int: new id
        """
        if self._max_id is None:
            if self._feature_field is not None:
                self._max_id = np.max(self._feature_field).item()
            else:
                self._max_id = 0
        self._max_id += 1
        return self._max_id

    def promote_provisional_ids(self) -> None:
        """
        Promote "provisional_id" to final "id" for all features, clearing the
        provisional id, and re-key the features dict by the final ids.
        """
        new_features_dict = {}

        for feature in self._features.values():
            if feature.provisional_id is not None:
                feature.id = feature.provisional_id
                feature.provisional_id = None
            new_features_dict[feature.id] = feature

        self._features = new_features_dict

    def update_fields_using_provisional_ids(self) -> None:
        """
        Rebuild the feature field so each feature is painted with its provisional
        id (falling back to its current id when no provisional id is set), and
        rebuild the lifetime field from each feature's lifetime.

        Raises:
            FeaturesNotFoundError: if the feature field or the features are not set
        """
        if self._feature_field is None:
            raise FeaturesNotFoundError(
                "Feature field is not set. Cannot update using provisional ids."
            )

        if not self._features:
            raise FeaturesNotFoundError(
                "Features have not been loaded into this Frame. "
                "Cannot update using provisional ids."
            )

        # Write into fresh arrays so masks are always computed against the
        # original, unmodified feature field.
        updated_feature_field = np.zeros_like(self._feature_field)
        updated_lifetime_field = np.zeros_like(self._feature_field)

        for feature in self._features.values():
            feature_mask = self._feature_field == feature.id
            updated_lifetime_field[feature_mask] = feature.lifetime
            if feature.provisional_id is not None:
                updated_feature_field[feature_mask] = feature.provisional_id
            else:
                updated_feature_field[feature_mask] = feature.id

        self._feature_field = updated_feature_field
        self._lifetime_field = updated_lifetime_field

    def get_new_features(self) -> list:
        """
        Get a list of all features in the frame for which Feature.is_new() is true,
        i.e. features that do not match with a feature from the previous frame and
        have not split from a feature in the previous frame
        """
        return [feature for feature in self._features.values() if feature.is_new()]

    def get_dissipating_features(self) -> list:
        """
        Get a list of all features in the frame for which Feature.is_dissipating()
        is true, i.e. features that do not match with a feature in the subsequent
        frame and have not merged with a feature in the subsequent frame
        """
        return [
            feature for feature in self._features.values() if feature.is_dissipating()
        ]

    def get_init_field(self, centroid_only: bool = False) -> NDArray:
        """
        Get a binary field of locations where features are newly initialising, where
        new features are ones that are not matched with a feature in the previous frame,
        and have not split from a feature in the previous frame
        """
        return self.get_field("init", centroid_only)

    def get_dissipation_field(self, centroid_only: bool = False) -> NDArray:
        """
        Get a binary field of locations where features are dissipating, where
        these are ones that are not matched with a feature in the next frame, and
        do not merge with a feature in the next frame
        """
        return self.get_field("dissipation", centroid_only)

    def get_field(self, field_type: str, centroid_only: bool = True) -> NDArray:
        """
        Get a binary field of locations where features meet the input requirement,
        as specified by field type

        Args:
            field_type (str):
                "init": Get the field of all new features in the frame, where new
                features are ones that are not matched with a feature in the previous
                frame, and have not split from a feature in the previous frame
                "dissipation": Get the field of all dissipating features in the
                frame, where these are ones that are not matched with a feature in
                the next frame, and do not merge with a feature in the next frame
            centroid_only (bool, optional):
                Whether the binary output should contain just the feature centroids
                or should span the full feature shape.
                Defaults to True.

        Raises:
            KeyError: if field_type is not one of the supported types
            FeaturesNotFoundError: if no feature field has been set on the frame

        Returns:
            NDArray: field with the same shape and dtype as the feature field,
            1 at selected locations and 0 elsewhere
        """
        feature_methods = {
            "init": self.get_new_features,
            "dissipation": self.get_dissipating_features,
        }
        if field_type not in feature_methods:
            raise KeyError(f"field_type must be one of {feature_methods.keys()}")

        # Without a feature field there is no grid to build the output on
        if self._feature_field is None:
            raise FeaturesNotFoundError(
                "Feature field is not set. Cannot construct field."
            )

        field = np.zeros_like(self._feature_field)
        for feature in feature_methods[field_type]():
            if centroid_only:
                # Round centroid to nearest integer and cast to int;
                # tuple to ensure a single grid point is indexed
                centroid_coord = tuple(np.rint(feature.centroid).astype(int))
                field[centroid_coord] = 1
            else:
                # Populate field with full size of feature
                field[self._feature_field == feature.id] = 1
        return field
380
+
381
+
382
class Timeline:
    """
    Object for storing and accessing Frames, stored as a dict of dt.datetime keys
    and Frame values.
    """

    def __init__(self):
        # Mapping of frame validity time -> Frame
        self.timeline = {}

    def add_to_timeline(self, frame: Frame) -> None:
        """
        Add the input frame to the timeline, keyed by frame.time. A frame already
        stored under the same time is replaced.

        Raises:
            TypeError: if frame is not a Frame
            ValueError: if the frame's time has not been set
        """
        if not isinstance(frame, Frame):
            raise TypeError(f"Expected type Frame, got {type(frame)}")
        if frame.time is None:
            raise ValueError("Frame time is not set. Cannot add to timeline.")
        self.timeline[frame.time] = frame

    def add_to_timelime(self, frame: Frame) -> None:
        """
        Misspelled original name of add_to_timeline, kept so existing callers
        continue to work. Behaves identically to add_to_timeline.
        """
        self.add_to_timeline(frame)

    def get_previous_frame(self, current_time: dt.datetime) -> Union[Frame, None]:
        """
        Finds the frame with the closest time to the input time, and which
        is in the past.

        Returns:
            The Frame with the latest time strictly before current_time, or None
            when the timeline holds exactly one frame.

        Raises:
            ValueError: if the timeline is empty, or if it holds more than one
                frame but none of them is earlier than current_time
        """
        if not self.timeline:
            raise ValueError("Timeline is empty. No previous frame to return.")
        # NOTE(review): with a single stored frame this returns None without
        # comparing its time to current_time -- presumably that frame is assumed
        # to be the current one; confirm against callers.
        if len(self.timeline) == 1:
            return None

        prev_times = [time for time in self.timeline if time < current_time]
        if not prev_times:
            raise ValueError("No previous frame found in timeline")
        return self.timeline[max(prev_times)]

    def purge_old_frame(self, max_frames: int = 2) -> None:
        """
        Placeholder for removing frames that are no longer needed, as defined by
        max_frames. Not implemented: currently a no-op that leaves the timeline
        untouched.
        """
        pass

    def get_timeline(self) -> dict:
        """
        Return the timeline as a dictionary of values, with keys being the validity
        time and values being the frame at that validity time.
        """
        return self.timeline

    def get_frame(self, time: dt.datetime) -> Frame:
        """
        Get the frame that is valid at the input time. Raises ValueError if frame
        matching the input time is not found.
        """
        if time not in self.timeline:
            raise ValueError(f"No frame found for time {time}")
        return self.timeline[time]
437
+
438
+
439
def label_features(
    field: NDArray[np.floating],
    min_area: float,
    threshold: float,
    under_threshold: bool = False,
    connectivity_structure: str = "default",
) -> NDArray[np.integer]:
    """
    Label distinct regions in the input field that meet a specified threshold condition.

    Args:
        field (np.ndarray):
            2D input array of data to be labelled. If a MaskedArray is given,
            masked points are always treated as background.
        min_area (float):
            Minimum area (in number of grid points) for a region to be considered valid
        threshold (float):
            Threshold value for identifying regions
        under_threshold (bool, optional):
            If True, regions strictly under the threshold are considered;
            if False, regions strictly over the threshold are considered.
            Defaults to False.
        connectivity_structure (str, optional):
            Str defining boolean connectivity for region labelling.
            Only "default" is supported, which is 8-way connectivity, meaning all
            cardinal AND diagonal neighbours that meet the threshold condition are
            considered part of the same region (structure np.ones((3, 3))).
            An alternative arrangement (not yet supported) would be "cardinal"
            4-way connectivity (diagonals omitted), defined as:
                np.array([[0, 1, 0],
                          [1, 1, 1],
                          [0, 1, 0]])
            See scipy.ndimage.label documentation for more details.

    Raises:
        ValueError: min_area must be a non-negative number
        ValueError: threshold must be a number
        TypeError: under_threshold must be a boolean
        NotImplementedError: connectivity_structure is anything but "default"
        Any error raised by check_arrays for an invalid field

    Returns:
        NDArray[np.int_]: 2D Integer field of labelled regions,
            same shape as input field. 0 is background.
    """
    # Function-scope import so this block does not depend on a file-level import
    import numbers

    # Check input types
    field = check_arrays(field, ndim=2)

    # Remember which points are masked before filling them. The fill value is a
    # data value, not a label, so it could itself satisfy the threshold condition
    # (e.g. under_threshold=True with threshold > 0).
    masked_points = None
    if isinstance(field, np.ma.MaskedArray):
        masked_points = np.ma.getmaskarray(field)
        field = field.filled(fill_value=0)

    # numbers.Real also covers numpy scalar types such as np.float32 / np.int64
    if not isinstance(min_area, numbers.Real) or min_area < 0:
        raise ValueError("min_area must be a non-negative number")
    if not isinstance(threshold, numbers.Real):
        raise ValueError("threshold must be a number")
    if not isinstance(under_threshold, (bool, np.bool_)):
        raise TypeError("under_threshold must be a boolean")

    is_default = (
        isinstance(connectivity_structure, str)
        and connectivity_structure == "default"
    )
    if not is_default:
        raise NotImplementedError("Only default connectivity supported")
    # All cardinal and diagonal points
    structure = np.ones((3, 3))

    # Construct feature field using threshold and threshold condition
    # Grid points meeting the condition are set to 1, others to 0
    if under_threshold:
        meets_condition = field < threshold
    else:
        meets_condition = field > threshold
    if masked_points is not None:
        # Masked points are background regardless of the fill value
        meets_condition = meets_condition & ~masked_points
    feature_field = meets_condition.astype(int)

    # Identify and label distinct regions in the feature field
    id_regions, num_ids = ndimage.label(feature_field, structure=structure)

    # Any regions smaller than the min_area are removed from the feature field
    # before re-running feature labelling, so the final ids are consecutive
    id_sizes = np.array(ndimage.sum(feature_field, id_regions, range(num_ids + 1)))
    too_small = id_sizes < min_area
    feature_field[too_small[id_regions]] = 0
    id_regions, _ = ndimage.label(feature_field, structure=structure)

    return id_regions