byotrack 1.2.0.dev1__tar.gz → 1.2.0.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/PKG-INFO +1 -1
  2. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/__init__.py +1 -1
  3. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/features_extractor.py +65 -1
  4. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/frame_by_frame/base.py +329 -25
  5. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/frame_by_frame/kalman_linker.py +21 -7
  6. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/frame_by_frame/koft.py +22 -10
  7. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/frame_by_frame/nearest_neighbor.py +21 -8
  8. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack.egg-info/PKG-INFO +1 -1
  9. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/LICENSE +0 -0
  10. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/README.md +0 -0
  11. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/__init__.py +0 -0
  12. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/detector/__init__.py +0 -0
  13. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/detector/detections.py +0 -0
  14. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/detector/detector.py +0 -0
  15. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/linker.py +0 -0
  16. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/optical_flow/__init__.py +0 -0
  17. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/optical_flow/optical_flow.py +0 -0
  18. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/parameters.py +0 -0
  19. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/refiner.py +0 -0
  20. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/tracker.py +0 -0
  21. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/api/tracks.py +0 -0
  22. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/dataset/__init__.py +0 -0
  23. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/dataset/ctc.py +0 -0
  24. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/example_data.py +0 -0
  25. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/fiji/__init__.py +0 -0
  26. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/fiji/io.py +0 -0
  27. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/fiji/run.py +0 -0
  28. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/icy/__init__.py +0 -0
  29. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/icy/io.py +0 -0
  30. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/icy/run.py +0 -0
  31. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/__init__.py +0 -0
  32. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/detector/__init__.py +0 -0
  33. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/detector/stardist.py +0 -0
  34. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/detector/wavelet.py +0 -0
  35. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/__init__.py +0 -0
  36. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/frame_by_frame/__init__.py +0 -0
  37. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/frame_by_frame/greedy_lap.py +0 -0
  38. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/icy_emht/__init__.py +0 -0
  39. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/icy_emht/emht_protocol.xml +0 -0
  40. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/icy_emht/emht_protocol_with_full_specs.xml +0 -0
  41. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/icy_emht/icy_emht.py +0 -0
  42. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/trackmate/__init__.py +0 -0
  43. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/trackmate/_trackmate.py +0 -0
  44. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/linker/trackmate/trackmate.py +0 -0
  45. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/optical_flow/__init__.py +0 -0
  46. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/optical_flow/opencv.py +0 -0
  47. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/optical_flow/skimage.py +0 -0
  48. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/__init__.py +0 -0
  49. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/cleaner.py +0 -0
  50. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/interpolater.py +0 -0
  51. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/propagation.py +0 -0
  52. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/stitching/__init__.py +0 -0
  53. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/stitching/dist_stitcher.py +0 -0
  54. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/implementation/refiner/stitching/emc2.py +0 -0
  55. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/metrics/__init__.py +0 -0
  56. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/metrics/ctc.py +0 -0
  57. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/py.typed +0 -0
  58. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/utils.py +0 -0
  59. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/video/__init__.py +0 -0
  60. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/video/reader.py +0 -0
  61. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/video/transforms.py +0 -0
  62. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/video/video.py +0 -0
  63. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack/visualize.py +0 -0
  64. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack.egg-info/SOURCES.txt +0 -0
  65. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack.egg-info/dependency_links.txt +0 -0
  66. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack.egg-info/requires.txt +0 -0
  67. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/byotrack.egg-info/top_level.txt +0 -0
  68. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/pyproject.toml +0 -0
  69. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/setup.cfg +0 -0
  70. {byotrack-1.2.0.dev1 → byotrack-1.2.0.dev2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: byotrack
3
- Version: 1.2.0.dev1
3
+ Version: 1.2.0.dev2
4
4
  Summary: Biological particle tracking with Python
5
5
  Home-page: https://github.com/raphaelreme/byotrack
6
6
  Author: Raphael Reme
@@ -83,4 +83,4 @@ from byotrack.api.tracks import Track
83
83
  from byotrack.video import Video, VideoTransformConfig
84
84
 
85
85
 
86
- __version__ = "1.2.0.dev1"
86
+ __version__ = "1.2.0.dev2"
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
4
4
  from typing import Iterable
5
5
 
6
6
  import numpy as np
7
+ import numba # type: ignore
7
8
  import torch
8
9
 
9
10
  import byotrack # pylint: disable=cyclic-import
@@ -56,4 +57,67 @@ class MultiFeaturesExtractor(FeaturesExtractor):
56
57
  return torch.cat(features, dim=-1)
57
58
 
58
59
 
59
- # TODO: Add some examples
60
+ class MassExtractor(FeaturesExtractor):
61
+ """Extract the mass of each detection (number of pixels)"""
62
+
63
+ def __call__(self, frame: np.ndarray, detections: byotrack.Detections):
64
+ torch.tensor(compute_mass(detections.segmentation.numpy()), dtype=torch.float32)
65
+
66
+
67
+ class IntensityExtractor(FeaturesExtractor):
68
+ """Extract the sum of intensities of each detection"""
69
+
70
+ def __call__(self, frame: np.ndarray, detections: byotrack.Detections):
71
+ torch.tensor(compute_intensity(detections.segmentation.numpy(), frame.sum(axis=-1)), dtype=torch.float32)
72
+
73
+
74
+ @numba.njit
75
+ def compute_mass(segmentation: np.ndarray) -> np.ndarray:
76
+ """Extract the number of pixels of each detection
77
+
78
+ Args:
79
+ segmentation (np.ndarray): Segmentation mask
80
+
81
+ Returns:
82
+ np.ndarray: Mass for each object
83
+
84
+ """
85
+ n = segmentation.max()
86
+ mass = np.zeros(n, dtype=np.uint)
87
+
88
+ # Ravel in 1D
89
+ segmentation = segmentation.reshape(-1)
90
+
91
+ for i in range(segmentation.shape[0]):
92
+ instance = segmentation[i] - 1
93
+ if instance != -1:
94
+ mass[instance] += 1
95
+
96
+ return mass
97
+
98
+
99
+ @numba.njit
100
+ def compute_intensity(segmentation: np.ndarray, frame: np.ndarray) -> np.ndarray:
101
+ """Extract the cumulated intensity of each detection
102
+
103
+ Args:
104
+ segmentation (np.ndarray): Segmentation mask
105
+ frame (np.ndarray): Video frame (should have the same number of pixels than segmentation)
106
+
107
+ Returns:
108
+ np.ndarray: Sum of intensity for each object
109
+
110
+ """
111
+ n = segmentation.max()
112
+ intensity = np.zeros(n, dtype=frame.dtype)
113
+
114
+ # Ravel in 1D
115
+ segmentation = segmentation.reshape(-1)
116
+ frame = frame.reshape(-1)
117
+
118
+ for i in range(segmentation.shape[0]):
119
+ instance = segmentation[i] - 1
120
+ if instance != -1:
121
+ intensity[instance] += frame[i]
122
+
123
+ return intensity
@@ -1,6 +1,6 @@
1
1
  from abc import abstractmethod
2
2
  import dataclasses
3
- from typing import List, Optional, Tuple, Union
3
+ from typing import Dict, List, Optional, Tuple, Union
4
4
  import warnings
5
5
 
6
6
  import enum
@@ -67,7 +67,7 @@ class AssociationMethod(enum.Enum):
67
67
  return torch.tensor(greedy_assignment_solver(cost.numpy(), eta).astype(np.int32))
68
68
 
69
69
 
70
- class TrackHandler:
70
+ class TrackHandler: # pylint: disable=too-many-instance-attributes
71
71
  """Handle a track during the tracking procedure
72
72
 
73
73
  It accumulates the track data at each new association and store the optional motion model data.
@@ -91,6 +91,9 @@ class TrackHandler:
91
91
  detection_ids (List[int]): Identifiers of the associated detection (-1 if None)
92
92
  track_ids (List[int]): Index of the track at each frame in the `linker.active_tracks` list.
93
93
  It allows the linker to store data as tensor and be able to rebuild tracks at the end.
94
+ merge_id (int): Identifier to an optional merged track handler (See `Tracks.merge_id`)
95
+ parent_id (int): Identifier to an optional parent track handler (See `Tracks.parent_id`)
96
+ is_split (bool): Just to know if the track splits
94
97
 
95
98
  """
96
99
 
@@ -122,8 +125,14 @@ class TrackHandler:
122
125
  self.last_association = 0
123
126
  self.detection_ids: List[int] = []
124
127
  self.track_ids: List[int] = []
128
+ self.merge_id = -1
129
+ self.parent_id = -1
130
+ self.is_split = False
125
131
 
126
132
  def __len__(self) -> int:
133
+ if self.merge_id != -1 or self.is_split:
134
+ return len(self.detection_ids) # Last points counts
135
+
127
136
  return len(self.detection_ids) - self.last_association
128
137
 
129
138
  def is_active(self) -> bool:
@@ -210,13 +219,16 @@ class OnlineFlowExtractor:
210
219
  class FrameByFrameLinkerParameters: # pylint: disable=too-many-instance-attributes
211
220
  """Parameters of the abstract FrameByFrameLinker
212
221
 
222
+ Note:
223
+ The merging and splitting features are still experimental.
224
+
213
225
  Attributes:
214
226
  association_threshold (float): This is the main hyperparameter, it defines the threshold on the distance used
215
227
  not to link tracks with detections. It prevents to link with false positive detections.
216
228
  Default: 5 pixels
217
- n_valid (int): Number of frames with a correct association required to validate the track at its creation.
229
+ n_valid (int): Number of associated detections required to validate the track after its creation.
218
230
  Default: 3
219
- n_gap (int): Number of frames with no association before the track termination.
231
+ n_gap (int): Number of consecutive frames without association before the track termination.
220
232
  Default: 3
221
233
  association_method (AssociationMethod): The frame-by-frame association to use. See `AssociationMethod`.
222
234
  It can be provided as a string. (Choice: GREEDY, OPT_HARD, OPT_SMOOTH)
@@ -224,6 +236,12 @@ class FrameByFrameLinkerParameters: # pylint: disable=too-many-instance-attribu
224
236
  anisotropy (Tuple[float, float, float]): Anisotropy of images (Ratio of the pixel sizes
225
237
  for each axis, depth first). This will be used to scale distances.
226
238
  Default: (1., 1., 1.)
239
+ split_factor (float): Allow splitting of tracks, using a second association step.
240
+ The association threshold in this case is `split_factor * association_threshold`.
241
+ Default: 0.0 (No splits)
242
+ merge_factor (float): Allow merging of tracks, using a second association step.
243
+ The association threshold in this case is `merge_factor * association_threshold`.
244
+ Default: 0.0 (No merges)
227
245
 
228
246
  """
229
247
 
@@ -235,6 +253,8 @@ class FrameByFrameLinkerParameters: # pylint: disable=too-many-instance-attribu
235
253
  n_gap=3,
236
254
  association_method: Union[str, AssociationMethod] = AssociationMethod.OPT_SMOOTH,
237
255
  anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0),
256
+ split_factor: float = 0.0,
257
+ merge_factor: float = 0.0,
238
258
  ):
239
259
  self.association_threshold = association_threshold
240
260
  self.n_valid = n_valid
@@ -245,24 +265,32 @@ class FrameByFrameLinkerParameters: # pylint: disable=too-many-instance-attribu
245
265
  else AssociationMethod[association_method.upper()]
246
266
  )
247
267
  self.anisotropy = anisotropy
268
+ self.split_factor = split_factor
269
+ self.merge_factor = merge_factor
270
+
271
+ if merge_factor >= 1.0 or split_factor >= 1.0:
272
+ warnings.warn("Merge or split factors should be lower than 1")
248
273
 
249
274
  association_threshold: float = 5.0
250
275
  n_valid: int = 3
251
276
  n_gap: int = 3
252
277
  association_method: AssociationMethod = AssociationMethod.OPT_SMOOTH
253
278
  anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0)
279
+ split_factor: float = 0.0
280
+ merge_factor: float = 0.0
254
281
 
255
282
 
256
283
  class FrameByFrameLinker(byotrack.OnlineLinker):
257
284
  """Links detections online using frame-by-frame association
258
285
 
259
- Abstract class for frame-by-frame linker. It decomposes the update step in 5 parts:
286
+ Abstract class for frame-by-frame linker. It decomposes the update step in 6 parts:
260
287
 
261
- 1. Optional optical flow computations (Handled by this class)
288
+ 1. Optional optical flow computations (Handled by this class with the `optflow` given)
262
289
  2. Motion modeling to predict track positions (`motion_model`)
263
- 3. Track-to-detection cost computation (`cost`)
264
- 4. Solving the linear association problem (Handled by this class)
265
- 5. Post matching update to handle tracks (`post_association`)
290
+ 3. Features extraction (handled by this class with the `features_extractor` given)
291
+ 4. Track-to-detection cost computation (`cost`)
292
+ 5. Solving the linear association problem (handled in `associate`)
293
+ 6. Post matching update to handle tracks (`post_association`)
266
294
 
267
295
  The association relies on the AssociationMethod enum and tracks handling is done with
268
296
  TrackHandler.
@@ -287,6 +315,10 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
287
315
  active_tracks (List[TrackHandler]): Current track handlers
288
316
  all_positions (List[torch.Tensor]): Positions of the active tracks at each seen frames.
289
317
  Using the valid track handlers `track_ids`, it allows the reconstruction of tracks.
318
+ split_links (torch.Tensor): Current split_links
319
+ shape: (L', 2), dtype: int32
320
+ merge_links (torch.Tensor): Current merge_links
321
+ shape: (L'', 2), dtype: int32
290
322
 
291
323
  """
292
324
 
@@ -306,6 +338,8 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
306
338
  self.inactive_tracks: List[TrackHandler] = []
307
339
  self.active_tracks: List[TrackHandler] = []
308
340
  self.all_positions: List[torch.Tensor] = []
341
+ self.split_links = torch.zeros((0, 2), dtype=torch.int32)
342
+ self.merge_links = torch.zeros((0, 2), dtype=torch.int32)
309
343
 
310
344
  def reset(self) -> None:
311
345
  super().reset()
@@ -315,6 +349,8 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
315
349
  self.inactive_tracks = []
316
350
  self.active_tracks = []
317
351
  self.all_positions = []
352
+ self.split_links = torch.zeros((0, 2), dtype=torch.int32)
353
+ self.merge_links = torch.zeros((0, 2), dtype=torch.int32)
318
354
 
319
355
  def collect(self) -> List[byotrack.Track]:
320
356
  tracks = []
@@ -337,6 +373,8 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
337
373
  points,
338
374
  handler.identifier,
339
375
  torch.tensor(handler.detection_ids[: len(handler)], dtype=torch.int32),
376
+ parent_id=handler.parent_id,
377
+ merge_id=handler.merge_id,
340
378
  )
341
379
  )
342
380
  return tracks
@@ -382,6 +420,9 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
382
420
  Finally, it is also responsible to register the position of each active track in `all_positions`
383
421
  for the current time frame.
384
422
 
423
+ See `update_active_tracks` which can be called inside this implementation to handle tracks termination
424
+ and creation.
425
+
385
426
  Args:
386
427
  frame (np.ndarray): The current frame of the video
387
428
  Shape: (H, W, C), dtype: float
@@ -391,7 +432,201 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
391
432
 
392
433
  """
393
434
 
394
- def update_active_tracks(self, links: torch.Tensor) -> torch.Tensor:
435
+ def update_active_tracks( # pylint: disable=too-many-branches,too-many-statements,too-many-locals
436
+ self, links: torch.Tensor, detections: byotrack.Detections
437
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
438
+ """Updates tracks handler and creates new ones for extra detections
439
+
440
+ Tracks that are terminated are stored inside `inactive_tracks` and dropped from `active_tracks`.
441
+ It can be called inside `post_association` to facilitate the code.
442
+
443
+ It also handles merges and splits. In the case of some specific merges, it may change a few links.
444
+ The updated links are returned, with a still_active mask for tracks and a new_track mask
445
+ for detections.
446
+
447
+ Args:
448
+ links (torch.Tensor): The links made between active tracks and the detections
449
+ Shape: (L, 2), dtype: int32
450
+ detections (byotrack.Detections): Detections for the given frame
451
+
452
+ Returns:
453
+ torch.Tensor: Updated links (in case of some specific merges)
454
+ Shape: (L, 2), dtype: int32
455
+ torch.Tensor: Boolean tensor indicating True for still active tracks
456
+ Shape: (N_tracks), dtype: bool
457
+ torch.Tensor: Boolean tensor indicating True for newly created tracks from detections
458
+ Shape: (N_dets), dtype: bool
459
+
460
+ """
461
+ if self.split_links.shape[0] + self.merge_links.shape[0] == 0: # Fall back to old simpler version
462
+ return links, self._update_active_tracks(links), self._handle_extra_detections(detections, links)
463
+
464
+ # Ugly to handle merge and splits smoothly... If you do not care about that, read only the old simpler version
465
+ # Or even use offline split and merge strategies.
466
+
467
+ # Find unmatched measures
468
+ unmatched = torch.full((len(detections),), True)
469
+ unmatched[links[:, 1]] = False
470
+
471
+ # Create new tracks from unmatched measures
472
+ new_tracks: List[TrackHandler] = []
473
+ for i in torch.arange(len(detections))[unmatched].tolist():
474
+ track = TrackHandler(
475
+ self.specs.n_valid,
476
+ self.specs.n_gap,
477
+ self.frame_id,
478
+ len(self.inactive_tracks) + len(self.active_tracks) + len(new_tracks),
479
+ )
480
+ track.update(self.frame_id, i)
481
+ new_tracks.append(track)
482
+
483
+ # Lots of useful identifiers mapping
484
+ det_to_new_track = torch.full((len(detections),), -1, dtype=torch.int32)
485
+ det_to_new_track[torch.arange(len(detections))[unmatched]] = torch.arange(len(new_tracks), dtype=torch.int32)
486
+
487
+ track_to_det = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
488
+ track_to_det[links[:, 0]] = links[:, 1]
489
+ det_to_track = torch.full((len(detections),), -1, dtype=torch.int32)
490
+ det_to_track[links[:, 1]] = links[:, 0]
491
+
492
+ track_to_det_split = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
493
+ track_to_det_split[self.split_links[:, 0]] = self.split_links[:, 1]
494
+
495
+ track_to_det_merge = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
496
+ track_to_det_merge[self.merge_links[:, 0]] = self.merge_links[:, 1]
497
+ det_to_track_merge = torch.full((len(detections),), -1, dtype=torch.int32)
498
+ det_to_track_merge[self.merge_links[:, 1]] = self.merge_links[:, 0]
499
+
500
+ # Update active tracks (for merges and splits, they are still active but replaced by a new handler)
501
+ active_mask = torch.full((len(self.active_tracks),), False)
502
+ still_active: List[TrackHandler] = []
503
+ merges_to_ref: List[TrackHandler] = []
504
+ det_to_merge_id: Dict[int, int] = {}
505
+ identifier = len(self.inactive_tracks) + len(self.active_tracks) + len(new_tracks)
506
+
507
+ for i, track in enumerate(self.active_tracks):
508
+ if track_to_det_split[i] != -1: # The track splits
509
+ # For the corner case of a hypothetical track, it is now valid (after all, we found 2 detections for it)
510
+ track.track_state = TrackHandler.TrackState.FINISHED
511
+ track.is_split = True
512
+ self.inactive_tracks.append(track)
513
+
514
+ # Replace by a new VALID handler
515
+ other = TrackHandler(self.specs.n_valid, self.specs.n_gap, self.frame_id, identifier)
516
+ other.track_state = TrackHandler.TrackState.VALID
517
+ identifier += 1
518
+
519
+ # Set parent for both new tracks
520
+ other.parent_id = track.identifier
521
+ new_tracks[det_to_new_track[track_to_det_split[i]]].parent_id = track.identifier
522
+
523
+ # Swap other with track (the splitted one is finished)
524
+ track = other
525
+
526
+ if track_to_det[i] == -1 and track_to_det_merge[i] != -1: # The track is the second branch of a merge
527
+ assert det_to_track[track_to_det_merge[i]] != -1
528
+ # Get the first branch
529
+ other = self.active_tracks[det_to_track[track_to_det_merge[i]]]
530
+
531
+ # If hypothetical, nothing to do, we do not merge with hypothetical tracks
532
+ if track.track_state != TrackHandler.TrackState.HYPOTHETICAL:
533
+ if other.track_state == TrackHandler.TrackState.HYPOTHETICAL:
534
+ # If other is hypothetical, we do not merge but we swap the link
535
+ # This invalidates some mappings for both 'track' and 'other', but it is fine.
536
+ # If 2nd branch is first to be executed this way, the 1st branch will not be executed
537
+ track_to_det[det_to_track[track_to_det_merge[i]]] = -1
538
+ track_to_det[i] = track_to_det_merge[i]
539
+ else: # Merge (as this is the second track, the track is just dropped)
540
+ track.merge_id = int(track_to_det_merge[i]) # This is not merge id yet
541
+ merges_to_ref.append(track) # It will be updated once all tracks have been processed
542
+
543
+ track.track_state = TrackHandler.TrackState.FINISHED
544
+ self.inactive_tracks.append(track)
545
+ continue # It does its own specific update, the track cannot be updated any more.
546
+
547
+ if track_to_det[i] != -1 and det_to_track_merge[track_to_det[i]] != -1: # First branch of a merge
548
+ other = self.active_tracks[det_to_track_merge[track_to_det[i]]]
549
+
550
+ # If other is hypothetical, nothing to do, we do not merge with hypothetical tracks
551
+ if other.track_state != TrackHandler.TrackState.HYPOTHETICAL:
552
+ if track.track_state == TrackHandler.TrackState.HYPOTHETICAL:
553
+ # If track is hypothetical, we do not merge but we swap the link
554
+ # If 1st branch is first to be executed this way, the 2nd branch will not be executed
555
+ track_to_det[det_to_track_merge[track_to_det[i]]] = track_to_det[i]
556
+ track_to_det[i] = -1
557
+ else: # Merge: Create a new track and stop the former one
558
+ # Replace by a new VALID handler
559
+ other = TrackHandler(self.specs.n_valid, self.specs.n_gap, self.frame_id, identifier)
560
+ other.track_state = TrackHandler.TrackState.VALID
561
+ identifier += 1
562
+
563
+ track.merge_id = other.identifier
564
+ det_to_merge_id[int(track_to_det[i])] = other.identifier
565
+
566
+ # Terminate old track
567
+ track.track_state = TrackHandler.TrackState.FINISHED
568
+ self.inactive_tracks.append(track)
569
+
570
+ # Swap other with track (the merged one is finished)
571
+ track = other # Will be updated and kept in active
572
+
573
+ # Update track (classical link/no link, or a newly created one at a split/merge event)
574
+ track.update(self.frame_id, int(track_to_det[i].item()))
575
+
576
+ # Check if track is still active
577
+ if track.is_active():
578
+ still_active.append(track)
579
+ active_mask[i] = True
580
+ elif track.track_state == TrackHandler.TrackState.FINISHED or self.save_all:
581
+ self.inactive_tracks.append(track)
582
+ elif track.track_state == TrackHandler.TrackState.INVALID and track.parent_id != -1:
583
+ # We have to undo the splitting if a split track is not validated
584
+ # Let's find the finished parent track handler and the valid other child handler
585
+ # and concatenate back the data in a single valid track handler
586
+ # This is slow, but it should not occur often. To be really faster, it requires other data structures
587
+ # To be changed if it is necessary
588
+ parent = [other for other in self.inactive_tracks if other.identifier == track.parent_id][0]
589
+ child = [
590
+ other for other in self.active_tracks + self.inactive_tracks if other.parent_id == track.parent_id
591
+ ][0]
592
+
593
+ self.inactive_tracks.remove(parent)
594
+
595
+ if child.is_active():
596
+ child_index = self.active_tracks.index(child)
597
+ else:
598
+ child_index = self.inactive_tracks.index(child)
599
+
600
+ concatenated_handler = TrackHandler(
601
+ self.specs.n_valid, self.specs.n_gap, parent.start, child.identifier
602
+ )
603
+ concatenated_handler.track_state = child.track_state
604
+ concatenated_handler.last_association = child.last_association
605
+ concatenated_handler.parent_id = parent.parent_id
606
+ concatenated_handler.merge_id = child.merge_id
607
+ concatenated_handler.track_ids = parent.track_ids + child.track_ids
608
+ concatenated_handler.detection_ids = parent.detection_ids + child.detection_ids
609
+ concatenated_handler.is_split = child.is_split
610
+
611
+ # Replace child by concatenated
612
+ if concatenated_handler.is_active():
613
+ self.active_tracks[child_index] = concatenated_handler
614
+ else:
615
+ self.inactive_tracks[child_index] = concatenated_handler
616
+
617
+ # Relabel merges
618
+ for track in merges_to_ref:
619
+ track.merge_id = det_to_merge_id[track.merge_id]
620
+
621
+ # Relabel links
622
+ links[:, 0] = torch.arange(len(track_to_det))[track_to_det != -1]
623
+ links[:, 1] = track_to_det[track_to_det != -1]
624
+
625
+ self.active_tracks = still_active + new_tracks
626
+
627
+ return links, active_mask, unmatched
628
+
629
+ def _update_active_tracks(self, links: torch.Tensor) -> torch.Tensor:
395
630
  """Calls `update` for active tracks and return a boolean mask that indicates which track is still active
396
631
 
397
632
  Tracks that are terminated are stored inside `inactive_tracks` and dropped from `active_tracks`.
@@ -424,37 +659,113 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
424
659
 
425
660
  return active_mask
426
661
 
427
- def handle_extra_detections(self, detections: byotrack.Detections, links: torch.Tensor) -> torch.Tensor:
662
+ def _handle_extra_detections(self, detections: byotrack.Detections, links: torch.Tensor) -> torch.Tensor:
428
663
  """Handle extra detections by creating new track handlers
429
664
 
430
- It can be called inside `post_association` to create track handlers from extra detections. It will
431
- return the extra detections positions and ids to be further used by `post_association`.
665
+ It can be called inside `post_association` to create track handlers from extra detections. It
666
+ returns a boolean mask indicating a track creation for each detection.
432
667
 
433
668
  Args:
434
669
  detections (byotrack.Detections): Detections for the given frame
435
670
  links (torch.Tensor): The links made between active tracks and the detections
436
671
  Shape: (L, 2), dtype: int32
437
672
 
673
+ Returns:
674
+ torch.Tensor: Boolean tensor indicating True for newly created tracks from detections
675
+ Shape: (N_dets), dtype: bool
676
+
438
677
  """
439
678
  # Find unmatched measures
440
679
  unmatched = torch.full((len(detections),), True)
441
680
  unmatched[links[:, 1]] = False
442
- unmatched_ids = torch.arange(len(detections))[unmatched]
443
- # unmatched_measures = detections.position[unmatched]
444
681
 
445
682
  # Create a new active track for each unmatched measure
446
- for i in range(unmatched_ids.shape[0]):
683
+ for i in torch.arange(len(detections))[unmatched].tolist():
447
684
  handler = TrackHandler(
448
685
  self.specs.n_valid, self.specs.n_gap, self.frame_id, len(self.inactive_tracks) + len(self.active_tracks)
449
686
  )
450
- handler.update(self.frame_id, int(unmatched_ids[i]))
687
+ handler.update(self.frame_id, i)
451
688
  self.active_tracks.append(handler)
452
689
 
453
690
  return unmatched
454
691
 
692
+ def update_detections(self, detections: byotrack.Detections) -> byotrack.Detections:
693
+ """Optional modification of the current detections based on the current state
694
+
695
+ This is called by `update` after the motion modeling but before the cost/association.
696
+
697
+ By default, it does not change anything.
698
+
699
+ Args:
700
+ detections (byotrack.Detections): Detections at the current frame
701
+
702
+ Returns:
703
+ byotrack.Detections: The (optionally modified) detections to use at this current frame
704
+
705
+ """
706
+ return detections
707
+
708
+ def associate(self, frame: np.ndarray, detections: byotrack.Detections) -> torch.Tensor:
709
+ """Produces links between the current tracks and detections
710
+
711
+ Optionally it handles merges and splits by associating a second time.
712
+
713
+ Args:
714
+ frame (np.ndarray): Current frame
715
+ detections (byotrack.Detections): Current detections
716
+
717
+ Returns:
718
+ torch.Tensor: Links (i, j)
719
+ Shape: (L, 2), dtype: int32
720
+ """
721
+
722
+ cost, threshold = self.cost(frame, detections)
723
+ links = self.specs.association_method.solve(cost, threshold)
724
+
725
+ if self.specs.merge_factor == 0 and self.specs.split_factor == 0:
726
+ return links # No merge or splits
727
+
728
+ unmatched_detections = torch.full((len(detections),), True)
729
+ unmatched_detections[links[:, 1]] = False
730
+ unmatched_tracks = torch.full((len(self.active_tracks),), True)
731
+ unmatched_tracks[links[:, 0]] = False
732
+ valid_tracks = torch.tensor(
733
+ [(track.track_state == TrackHandler.TrackState.VALID) for track in self.active_tracks], dtype=torch.bool
734
+ )
735
+
736
+ if self.specs.merge_factor > 0:
737
+ # We simply do a 2nd association between unassociated VALID tracks with associated detections
738
+ tracks_mask = unmatched_tracks & valid_tracks
739
+ self.merge_links = self.specs.association_method.solve(
740
+ cost[tracks_mask][:, ~unmatched_detections], threshold * self.specs.merge_factor
741
+ )
742
+
743
+ # Relabel
744
+ self.merge_links[:, 0] = torch.arange(len(self.active_tracks))[tracks_mask][self.merge_links[:, 0]]
745
+ self.merge_links[:, 1] = torch.arange(len(detections))[~unmatched_detections][self.merge_links[:, 1]]
746
+
747
+ if self.specs.split_factor > 0:
748
+ # We simply do a 2nd association between associated tracks with unassociated detections
749
+ self.split_links = self.specs.association_method.solve(
750
+ cost[~unmatched_tracks][:, unmatched_detections], threshold * self.specs.split_factor
751
+ )
752
+
753
+ # Relabel
754
+ self.split_links[:, 0] = torch.arange(len(self.active_tracks))[~unmatched_tracks][self.split_links[:, 0]]
755
+ self.split_links[:, 1] = torch.arange(len(detections))[unmatched_detections][self.split_links[:, 1]]
756
+
757
+ return links
758
+
455
759
  def update(self, frame: np.ndarray, detections: byotrack.Detections) -> None:
456
760
  self.frame_id += 1
457
761
 
762
+ # Compute the flow map if optflow given
763
+ if self.optflow is not None:
764
+ self.optflow.update(frame)
765
+
766
+ self.motion_model()
767
+ detections = self.update_detections(detections)
768
+
458
769
  # Compute features if the extractor is given and register inside the detections
459
770
  # Do not recompute the features if some are already registered
460
771
  remove_feats = False
@@ -465,14 +776,7 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
465
776
  remove_feats = True
466
777
  self.features_extractor.register(frame, detections)
467
778
 
468
- # Compute the flow map if optflow given
469
- if self.optflow is not None:
470
- self.optflow.update(frame)
471
-
472
- self.motion_model()
473
- cost, threshold = self.cost(frame, detections)
474
- links = self.specs.association_method.solve(cost, threshold)
475
-
779
+ links = self.associate(frame, detections)
476
780
  self.post_association(frame, detections, links)
477
781
 
478
782
  assert len(self.all_positions[-1]) == len(
@@ -64,6 +64,9 @@ class TrackBuilding(enum.Enum):
64
64
  class KalmanLinkerParameters(FrameByFrameLinkerParameters):
65
65
  """Parameters of KalmanLinker
66
66
 
67
+ Note:
68
+ The merging and splitting features is still experimental.
69
+
67
70
  Attributes:
68
71
  association_threshold (float): This is the main hyperparameter, it defines the threshold on the distance used
69
72
  not to link tracks with detections. It prevents to link with false positive detections.
@@ -80,9 +83,9 @@ class KalmanLinkerParameters(FrameByFrameLinkerParameters):
80
83
  kalman_order (int): Order of the Kalman filter to use.
81
84
  0 for brownian motions, 1 for directed brownian motion, 2 for accelerated brownian motions, etc...
82
85
  Default: 1
83
- n_valid (int): Number of frames with a correct association required to validate the track at its creation.
86
+ n_valid (int): Number associated detections required to validate the track after its creation.
84
87
  Default: 3
85
- n_gap (int): Number of frames with no association before the track termination.
88
+ n_gap (int): Number of consecutive frames without association before the track termination.
86
89
  Default: 3
87
90
  association_method (AssociationMethod): The frame-by-frame association to use. See `AssociationMethod`.
88
91
  It can be provided as a string. (Choice: GREEDY, OPT_HARD, OPT_SMOOTH)
@@ -99,6 +102,12 @@ class KalmanLinkerParameters(FrameByFrameLinkerParameters):
99
102
  Either from detections, or from filtered/smoothed positions computed by the
100
103
  Kalman filter. See `TrackBuilding`. It can be provided as a string.
101
104
  Default: FILTERED
105
+ split_factor (float): Allow splitting of tracks, using a second association step.
106
+ The association threshold in this case is `split_factor * association_threshold`.
107
+ Default: 0.0 (No splits)
108
+ merge_factor (float): Allow merging of tracks, using a second association step.
109
+ The association threshold in this case is `merge_factor * association_threshold`.
110
+ Default: 0.0 (No merges)
102
111
 
103
112
  """
104
113
 
@@ -115,6 +124,8 @@ class KalmanLinkerParameters(FrameByFrameLinkerParameters):
115
124
  anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0),
116
125
  cost: Union[str, Cost] = Cost.EUCLIDEAN,
117
126
  track_building: Union[str, TrackBuilding] = TrackBuilding.FILTERED,
127
+ split_factor: float = 0.0,
128
+ merge_factor: float = 0.0,
118
129
  ):
119
130
  super().__init__(
120
131
  association_threshold=association_threshold,
@@ -122,6 +133,8 @@ class KalmanLinkerParameters(FrameByFrameLinkerParameters):
122
133
  n_gap=n_gap,
123
134
  anisotropy=anisotropy,
124
135
  association_method=association_method,
136
+ split_factor=split_factor,
137
+ merge_factor=merge_factor,
125
138
  )
126
139
 
127
140
  if isinstance(detection_std, float) and min(anisotropy) != max(anisotropy):
@@ -259,6 +272,8 @@ class KalmanLinker(FrameByFrameLinker):
259
272
  states.mean[handler.start : handler.start + len(handler), i, :dim, 0],
260
273
  handler.identifier,
261
274
  torch.tensor(handler.detection_ids[: len(handler)], dtype=torch.int32),
275
+ merge_id=handler.merge_id,
276
+ parent_id=handler.parent_id,
262
277
  )
263
278
  )
264
279
 
@@ -331,6 +346,9 @@ class KalmanLinker(FrameByFrameLinker):
331
346
  if self.active_states is None or self.kalman_filter is None or self.projections is None:
332
347
  raise RuntimeError("The linker should already be initialized.")
333
348
 
349
+ # Update handlers
350
+ links, active_mask, unmatched = self.update_active_tracks(links, detections)
351
+
334
352
  # Update the state of associated tracks (unassociated tracks keep the predicted state)
335
353
  self.active_states[links[:, 0]] = self.kalman_filter.update(
336
354
  self.active_states[links[:, 0]],
@@ -338,11 +356,7 @@ class KalmanLinker(FrameByFrameLinker):
338
356
  projection=self.projections[links[:, 0]],
339
357
  )
340
358
 
341
- # Update active track handlers
342
- active_mask = self.update_active_tracks(links)
343
-
344
- # Create new track handlers for unmatched detections
345
- unmatched = self.handle_extra_detections(detections, links)
359
+ # Create new states for unmatched measures
346
360
  unmatched_measures = detections.position[unmatched]
347
361
 
348
362
  # Build the initial states for tracks:
@@ -19,6 +19,9 @@ from .kalman_linker import Cost, KalmanLinker, KalmanLinkerParameters, TrackBuil
19
19
  class KOFTLinkerParameters(KalmanLinkerParameters):
20
20
  """Parameters of KOFTLinker
21
21
 
22
+ Note:
23
+ The merging and splitting features is still experimental.
24
+
22
25
  Attributes:
23
26
  association_threshold (float): This is the main hyperparameter, it defines the threshold on the distance used
24
27
  not to link tracks with detections. It prevents to link with false positive detections.
@@ -38,9 +41,9 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
38
41
  kalman_order (int): Order of the Kalman filter to use. 0 is not supported.
39
42
  1 for directed brownian motion, 2 for accelerated brownian motions, etc...
40
43
  Default: 1
41
- n_valid (int): Number of frames with a correct association required to validate the track at its creation.
44
+ n_valid (int): Number associated detections required to validate the track after its creation.
42
45
  Default: 3
43
- n_gap (int): Number of frames with no association before the track termination.
46
+ n_gap (int): Number of consecutive frames without association before the track termination.
44
47
  Default: 3
45
48
  association_method (AssociationMethod): The frame-by-frame association to use. See `AssociationMethod`.
46
49
  It can be provided as a string. (Choice: GREEDY, OPT_HARD, OPT_SMOOTH)
@@ -48,7 +51,7 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
48
51
  anisotropy (Tuple[float, float, float]): Anisotropy of images (Ratio of the pixel sizes
49
52
  for each axis, depth first). This will be used to scale distances. It will only impact
50
53
  EUCLIDEAN[_SQ] costs. For probabilistic cost, anisotropy should be already integrated
51
- in the different std of the kalman filter.
54
+ in the stds of the kalman filter (providing one std for each dimension).
52
55
  Default: (1., 1., 1.)
53
56
  cost_method (CostMethod): The cost method to use. It can be provided as a string.
54
57
  See `CostMethod`. It also indicates what is the correct unit of `association_threshold`.
@@ -57,6 +60,12 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
57
60
  Either from detections, or from filtered/smoothed positions computed by the
58
61
  Kalman filter. See `TrackBuilding`. It can be provided as a string.
59
62
  Default: FILTERED
63
+ split_factor (float): Allow splitting of tracks, using a second association step.
64
+ The association threshold in this case is `split_factor * association_threshold`.
65
+ Default: 0.0 (No splits)
66
+ merge_factor (float): Allow merging of tracks, using a second association step.
67
+ The association threshold in this case is `merge_factor * association_threshold`.
68
+ Default: 0.0 (No merges)
60
69
  extract_flows_on_detections (bool): If True it extracts the optical flow at the detection location if possible.
61
70
  Otherwise it extract the flow from the curent estimate of the track position.
62
71
  Default: False
@@ -66,7 +75,7 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
66
75
 
67
76
  """
68
77
 
69
- def __init__( # pylint: disable=too-many-arguments
78
+ def __init__( # pylint: disable=too-many-arguments, too-many-locals
70
79
  self,
71
80
  association_threshold: float,
72
81
  *,
@@ -80,6 +89,8 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
80
89
  anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0),
81
90
  cost: Union[str, Cost] = Cost.EUCLIDEAN,
82
91
  track_building: Union[str, TrackBuilding] = TrackBuilding.FILTERED,
92
+ split_factor: float = 0.0,
93
+ merge_factor: float = 0.0,
83
94
  extract_flows_on_detections=False,
84
95
  always_measure_velocity=True,
85
96
  ):
@@ -94,6 +105,8 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
94
105
  anisotropy=anisotropy,
95
106
  cost=cost,
96
107
  track_building=track_building,
108
+ split_factor=split_factor,
109
+ merge_factor=merge_factor,
97
110
  )
98
111
 
99
112
  if isinstance(flow_std, float) and min(anisotropy) != max(anisotropy):
@@ -148,9 +161,9 @@ class KOFTLinker(KalmanLinker):
148
161
  ) -> None:
149
162
  super().__init__(specs, optflow, features_extractor, save_all)
150
163
 
164
+ self.optflow: OnlineFlowExtractor
151
165
  self.specs: KOFTLinkerParameters
152
166
  assert self.optflow is not None, "KOFT requires an optical flow algorithm"
153
- self.optflow: OnlineFlowExtractor
154
167
 
155
168
  self.last_detections = byotrack.Detections(data={"position": torch.empty((0, 2))})
156
169
  self.n_initial = 0
@@ -290,6 +303,9 @@ class KOFTLinker(KalmanLinker):
290
303
 
291
304
  self.last_detections = detections # Save detections (May be required)
292
305
 
306
+ # Update handlers
307
+ links, active_mask, unmatched = self.update_active_tracks(links, detections)
308
+
293
309
  # Update the state of associated tracks (unassociated tracks keep the predicted state)
294
310
  self.active_states[links[:, 0]] = self.kalman_filter.update(
295
311
  self.active_states[links[:, 0]],
@@ -303,11 +319,7 @@ class KOFTLinker(KalmanLinker):
303
319
  measurement_noise=self.kalman_filter.measurement_noise[: detections.dim, : detections.dim],
304
320
  )
305
321
 
306
- # Update active track handlers
307
- active_mask = self.update_active_tracks(links)
308
-
309
- # Create new track handlers for unmatched detections
310
- unmatched = self.handle_extra_detections(detections, links)
322
+ # Create new states for unmatched measures
311
323
  unmatched_measures = detections.position[unmatched]
312
324
  self.n_initial = unmatched_measures.shape[0]
313
325
 
@@ -13,17 +13,23 @@ from .base import AssociationMethod, FrameByFrameLinker, FrameByFrameLinkerParam
13
13
  class NearestNeighborParameters(FrameByFrameLinkerParameters):
14
14
  """Parameters of NearestNeighborLinker
15
15
 
16
+ Note:
17
+ The merging and splitting features is still experimental.
18
+
16
19
  Attributes:
17
20
  association_threshold (float): This is the main hyperparameter, it defines the threshold on the distance used
18
21
  not to link tracks with detections. It prevents to link with false positive detections.
19
22
  Default: 5 pixels
20
- n_valid (int): Number of frames with a correct association required to validate the track at its creation.
23
+ n_valid (int): Number associated detections required to validate the track after its creation.
21
24
  Default: 3
22
- n_gap (int): Number of frames with no association before the track termination.
25
+ n_gap (int): Number of consecutive frames without association before the track termination.
23
26
  Default: 3
24
27
  association_method (AssociationMethod): The frame-by-frame association to use. See `AssociationMethod`.
25
28
  It can be provided as a string. (Choice: GREEDY, OPT_HARD, OPT_SMOOTH)
26
29
  Default: OPT_SMOOTH
30
+ anisotropy (Tuple[float, float, float]): Anisotropy of images (Ratio of the pixel sizes
31
+ for each axis, depth first). This will be used to scale distances.
32
+ Default: (1., 1., 1.)
27
33
  ema (float): Optional exponential moving average to reduce detection noise. Detection positions are smoothed
28
34
  using this EMA. Should be smaller than 1. It use: x_{t+1} = ema x_{t} + (1 - ema) det(t)
29
35
  As motion is not modeled, EMA may introduce lag that will hinder tracking. It is more effective with
@@ -35,6 +41,12 @@ class NearestNeighborParameters(FrameByFrameLinkerParameters):
35
41
  ForwardBackward interpolation using the same optical flow: it will produce
36
42
  smoother interpolations.
37
43
  Default: False
44
+ split_factor (float): Allow splitting of tracks, using a second association step.
45
+ The association threshold in this case is `split_factor * association_threshold`.
46
+ Default: 0.0 (No splits)
47
+ merge_factor (float): Allow merging of tracks, using a second association step.
48
+ The association threshold in this case is `merge_factor * association_threshold`.
49
+ Default: 0.0 (No merges)
38
50
 
39
51
  """
40
52
 
@@ -48,6 +60,8 @@ class NearestNeighborParameters(FrameByFrameLinkerParameters):
48
60
  anisotropy: Tuple[float, float, float] = (1.0, 1.0, 1.0),
49
61
  ema=0.0,
50
62
  fill_gap=False,
63
+ split_factor: float = 0.0,
64
+ merge_factor: float = 0.0,
51
65
  ):
52
66
  super().__init__( # pylint: disable=duplicate-code
53
67
  association_threshold=association_threshold,
@@ -55,6 +69,8 @@ class NearestNeighborParameters(FrameByFrameLinkerParameters):
55
69
  n_gap=n_gap,
56
70
  association_method=association_method,
57
71
  anisotropy=anisotropy,
72
+ split_factor=split_factor,
73
+ merge_factor=merge_factor,
58
74
  )
59
75
  self.ema = ema
60
76
  self.fill_gap = fill_gap
@@ -123,18 +139,15 @@ class NearestNeighborLinker(FrameByFrameLinker):
123
139
  if self.active_positions is None:
124
140
  self.active_positions = torch.empty((0, detections.position.shape[1]))
125
141
 
142
+ # Update handlers
143
+ links, active_mask, unmatched = self.update_active_tracks(links, detections)
144
+
126
145
  # Update tracks positions with detections
127
146
  # Optionally using an EMA to reduce detections noise
128
147
  self.active_positions[links[:, 0]] -= (1.0 - self.specs.ema) * (
129
148
  self.active_positions[links[:, 0]] - detections.position[links[:, 1]]
130
149
  )
131
150
 
132
- # Update active track handlers
133
- active_mask = self.update_active_tracks(links)
134
-
135
- # Create new track handlers for unmatched detections
136
- unmatched = self.handle_extra_detections(detections, links)
137
-
138
151
  # Merge still active positions and new ones
139
152
  self.active_positions = torch.cat((self.active_positions[active_mask], detections.position[unmatched]))
140
153
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: byotrack
3
- Version: 1.2.0.dev1
3
+ Version: 1.2.0.dev2
4
4
  Summary: Biological particle tracking with Python
5
5
  Home-page: https://github.com/raphaelreme/byotrack
6
6
  Author: Raphael Reme
File without changes
File without changes
File without changes
File without changes