byotrack 1.2.0.dev2__tar.gz → 1.2.0.dev3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/PKG-INFO +1 -1
  2. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/__init__.py +1 -1
  3. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/detector/detections.py +46 -0
  4. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/features_extractor.py +1 -26
  5. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/dataset/ctc.py +41 -4
  6. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/fiji/run.py +14 -3
  7. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/frame_by_frame/base.py +87 -83
  8. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/frame_by_frame/kalman_linker.py +6 -9
  9. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/frame_by_frame/koft.py +13 -14
  10. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/frame_by_frame/nearest_neighbor.py +6 -7
  11. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/trackmate/trackmate.py +1 -2
  12. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/metrics/ctc.py +152 -14
  13. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack.egg-info/PKG-INFO +1 -1
  14. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/LICENSE +0 -0
  15. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/README.md +0 -0
  16. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/__init__.py +0 -0
  17. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/detector/__init__.py +0 -0
  18. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/detector/detector.py +0 -0
  19. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/linker.py +0 -0
  20. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/optical_flow/__init__.py +0 -0
  21. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/optical_flow/optical_flow.py +0 -0
  22. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/parameters.py +0 -0
  23. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/refiner.py +0 -0
  24. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/tracker.py +0 -0
  25. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/api/tracks.py +0 -0
  26. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/dataset/__init__.py +0 -0
  27. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/example_data.py +0 -0
  28. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/fiji/__init__.py +0 -0
  29. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/fiji/io.py +0 -0
  30. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/icy/__init__.py +0 -0
  31. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/icy/io.py +0 -0
  32. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/icy/run.py +0 -0
  33. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/__init__.py +0 -0
  34. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/detector/__init__.py +0 -0
  35. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/detector/stardist.py +0 -0
  36. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/detector/wavelet.py +0 -0
  37. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/__init__.py +0 -0
  38. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/frame_by_frame/__init__.py +0 -0
  39. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/frame_by_frame/greedy_lap.py +0 -0
  40. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/icy_emht/__init__.py +0 -0
  41. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/icy_emht/emht_protocol.xml +0 -0
  42. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/icy_emht/emht_protocol_with_full_specs.xml +0 -0
  43. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/icy_emht/icy_emht.py +0 -0
  44. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/trackmate/__init__.py +0 -0
  45. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/linker/trackmate/_trackmate.py +0 -0
  46. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/optical_flow/__init__.py +0 -0
  47. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/optical_flow/opencv.py +0 -0
  48. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/optical_flow/skimage.py +0 -0
  49. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/__init__.py +0 -0
  50. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/cleaner.py +0 -0
  51. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/interpolater.py +0 -0
  52. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/propagation.py +0 -0
  53. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/stitching/__init__.py +0 -0
  54. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/stitching/dist_stitcher.py +0 -0
  55. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/implementation/refiner/stitching/emc2.py +0 -0
  56. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/metrics/__init__.py +0 -0
  57. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/py.typed +0 -0
  58. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/utils.py +0 -0
  59. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/video/__init__.py +0 -0
  60. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/video/reader.py +0 -0
  61. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/video/transforms.py +0 -0
  62. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/video/video.py +0 -0
  63. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack/visualize.py +0 -0
  64. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack.egg-info/SOURCES.txt +0 -0
  65. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack.egg-info/dependency_links.txt +0 -0
  66. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack.egg-info/requires.txt +0 -0
  67. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/byotrack.egg-info/top_level.txt +0 -0
  68. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/pyproject.toml +0 -0
  69. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/setup.cfg +0 -0
  70. {byotrack-1.2.0.dev2 → byotrack-1.2.0.dev3}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: byotrack
3
- Version: 1.2.0.dev2
3
+ Version: 1.2.0.dev3
4
4
  Summary: Biological particle tracking with Python
5
5
  Home-page: https://github.com/raphaelreme/byotrack
6
6
  Author: Raphael Reme
@@ -83,4 +83,4 @@ from byotrack.api.tracks import Track
83
83
  from byotrack.video import Video, VideoTransformConfig
84
84
 
85
85
 
86
- __version__ = "1.2.0.dev2"
86
+ __version__ = "1.2.0.dev3"
@@ -36,6 +36,32 @@ def _check_confidence(confidence: torch.Tensor) -> None:
36
36
  assert confidence.dtype is torch.float32
37
37
 
38
38
 
39
+ @numba.njit
40
+ def _compute_mass(segmentation: np.ndarray) -> np.ndarray:
41
+ """Extract the number of pixels of each object in the segmentation
42
+
43
+ Args:
44
+ segmentation (np.ndarray): Segmentation mask
45
+
46
+ Returns:
47
+ np.ndarray: Mass for each object
48
+ Shape (n,), dtype: int32
49
+
50
+ """
51
+ n = segmentation.max()
52
+ mass = np.zeros(n, dtype=np.int32)
53
+
54
+ # Ravel in 1D
55
+ segmentation = segmentation.reshape(-1)
56
+
57
+ for i in range(segmentation.shape[0]):
58
+ instance = segmentation[i] - 1
59
+ if instance != -1:
60
+ mass[instance] += 1
61
+
62
+ return mass
63
+
64
+
39
65
  @numba.njit(parallel=False)
40
66
  def _position_from_segmentation(segmentation: np.ndarray) -> np.ndarray:
41
67
  """Return the center (mean) of each instance in the segmentation"""
@@ -278,6 +304,8 @@ class Detections:
278
304
  Shape: ([D, ]H, W), dtype: int32
279
305
  confidence (torch.Tensor): Confidence for each instance
280
306
  Shape: (N,), dtype: float32
307
+ mass (torch.Tensor): Size of each object in pixel, inferred from the data.
308
+ Shape: (N,), dtype: int32
281
309
  use_median_position (bool): Use median instead of mean to compute positions from segmentation.
282
310
  Default: True (Usually more robust)
283
311
 
@@ -360,6 +388,15 @@ class Detections:
360
388
 
361
389
  return confidence
362
390
 
391
+ @property
392
+ def mass(self) -> torch.Tensor:
393
+ mass = self.data.get("mass", self._lazy_extrapolated_data.get("mass"))
394
+ if mass is None:
395
+ mass = self._extrapolate_mass()
396
+ self._lazy_extrapolated_data["confidence"] = mass
397
+
398
+ return mass
399
+
363
400
  @property
364
401
  def use_median_position(self) -> bool:
365
402
  return self._use_median_position
@@ -443,6 +480,15 @@ class Detections:
443
480
  """Extrapolate confidence"""
444
481
  return torch.ones(self.length)
445
482
 
483
+ def _extrapolate_mass(self) -> torch.Tensor:
484
+ if "segmentation" in self.data:
485
+ return torch.tensor(_compute_mass(self.data["segmentation"].numpy()), dtype=torch.int32)
486
+
487
+ if "bbox" in self.data:
488
+ return self.data["bbox"][:, self.dim :].prod(dim=-1)
489
+
490
+ return torch.ones(self.length, dtype=torch.int32)
491
+
446
492
  def __len__(self) -> int:
447
493
  return self.length
448
494
 
@@ -61,7 +61,7 @@ class MassExtractor(FeaturesExtractor):
61
61
  """Extract the mass of each detection (number of pixels)"""
62
62
 
63
63
  def __call__(self, frame: np.ndarray, detections: byotrack.Detections):
64
- torch.tensor(compute_mass(detections.segmentation.numpy()), dtype=torch.float32)
64
+ return detections.mass
65
65
 
66
66
 
67
67
  class IntensityExtractor(FeaturesExtractor):
@@ -71,31 +71,6 @@ class IntensityExtractor(FeaturesExtractor):
71
71
  torch.tensor(compute_intensity(detections.segmentation.numpy(), frame.sum(axis=-1)), dtype=torch.float32)
72
72
 
73
73
 
74
- @numba.njit
75
- def compute_mass(segmentation: np.ndarray) -> np.ndarray:
76
- """Extract the number of pixels of each detection
77
-
78
- Args:
79
- segmentation (np.ndarray): Segmentation mask
80
-
81
- Returns:
82
- np.ndarray: Mass for each object
83
-
84
- """
85
- n = segmentation.max()
86
- mass = np.zeros(n, dtype=np.uint)
87
-
88
- # Ravel in 1D
89
- segmentation = segmentation.reshape(-1)
90
-
91
- for i in range(segmentation.shape[0]):
92
- instance = segmentation[i] - 1
93
- if instance != -1:
94
- mass[instance] += 1
95
-
96
- return mass
97
-
98
-
99
74
  @numba.njit
100
75
  def compute_intensity(segmentation: np.ndarray, frame: np.ndarray) -> np.ndarray:
101
76
  """Extract the cumulated intensity of each detection
@@ -286,6 +286,8 @@ def _fast_disk_2d(
286
286
  positions: np.ndarray,
287
287
  identifiers: np.ndarray,
288
288
  radius: np.ndarray,
289
+ *,
290
+ overwrite=False,
289
291
  ):
290
292
  """Fast inplace drawing of disk in 2D
291
293
 
@@ -300,6 +302,8 @@ def _fast_disk_2d(
300
302
  Shape: (n,), dtype: uint16
301
303
  radius (np.ndarray): Radius of each disk
302
304
  Shape: (n, ), dtype: float32
305
+ overwrite (bool): Overwrite pixels that are already written (!=0)
306
+ Default: False
303
307
 
304
308
  """
305
309
  positions_round = np.round(positions).astype(np.int64)
@@ -315,6 +319,9 @@ def _fast_disk_2d(
315
319
  if not 0 <= i < segmentation.shape[0] or not 0 <= j < segmentation.shape[1]:
316
320
  continue
317
321
 
322
+ if not overwrite and segmentation[i, j] != 0 and best_dist[i, j] == np.inf:
323
+ continue
324
+
318
325
  delta = pos - positions[k]
319
326
  dist = delta @ delta
320
327
 
@@ -325,12 +332,15 @@ def _fast_disk_2d(
325
332
 
326
333
 
327
334
  @numba.njit(parallel=True)
328
- def _fast_disk_3d(
335
+ def _fast_disk_3d( # pylint: disable=too-many-locals
329
336
  segmentation: np.ndarray,
330
337
  bbox: np.ndarray,
331
338
  positions: np.ndarray,
332
339
  identifiers: np.ndarray,
333
340
  radius: np.ndarray,
341
+ *,
342
+ anisoptropy=1.0,
343
+ overwrite=False,
334
344
  ):
335
345
  """Fast inplace drawing of disk in 3D
336
346
 
@@ -345,6 +355,11 @@ def _fast_disk_3d(
345
355
  Shape: (n,), dtype: uint16
346
356
  radius (np.ndarray): Radius of each disk
347
357
  Shape: (n, ), dtype: float32
358
+ anisotropy (float): Relative size of a pixel along the depth dimension
359
+ versus height/width dimensions.
360
+ Default: 1.0
361
+ overwrite (bool): Overwrite pixels that are already written (!=0)
362
+ Default: False
348
363
 
349
364
  """
350
365
  positions_round = np.round(positions).astype(np.int64)
@@ -364,7 +379,11 @@ def _fast_disk_3d(
364
379
  ):
365
380
  continue
366
381
 
382
+ if not overwrite and segmentation[z, i, j] != 0 and best_dist[z, i, j] == np.inf:
383
+ continue
384
+
367
385
  delta = pos - positions[k]
386
+ delta[0] *= anisoptropy # Increase distance in Z by the anisotropy
368
387
  dist = delta @ delta
369
388
 
370
389
  if dist <= radius[k]:
@@ -378,6 +397,9 @@ def draw_disk(
378
397
  positions: np.ndarray,
379
398
  identifiers: np.ndarray,
380
399
  radius: np.ndarray,
400
+ *,
401
+ anisotropy=1.0,
402
+ overwrite=False,
381
403
  ):
382
404
  """Draw disks on the segmentation
383
405
 
@@ -390,21 +412,27 @@ def draw_disk(
390
412
  Shape: (n,), dtype: uint16
391
413
  radius (np.ndarray): Radius of each disk
392
414
  Shape: (n, ), dtype: float32)
415
+ anisotropy (float): Relative size of a pixel along the depth dimension
416
+ versus height/width dimensions.
417
+ Default: 1.0
418
+ overwrite (bool): Overwrite pixels that are already written (!=0)
419
+ Default: False
420
+
393
421
  """
394
422
  # Wrapped to redirect in 2D/3D and numba does not support np.indices in 3.8
395
423
  if segmentation.ndim == 3:
396
424
  thresh = round(radius.max())
397
425
  bbox: np.ndarray = np.indices((thresh * 2 + 1, thresh * 2 + 1, thresh * 2 + 1)).transpose(1, 2, 3, 0) - thresh
398
426
  bbox = bbox.reshape(-1, 3)
399
- _fast_disk_3d(segmentation, bbox, positions, identifiers, radius)
427
+ _fast_disk_3d(segmentation, bbox, positions, identifiers, radius, anisoptropy=anisotropy, overwrite=overwrite)
400
428
  else:
401
429
  thresh = int(round(radius.max()))
402
430
  bbox = np.indices((thresh * 2 + 1, thresh * 2 + 1)).transpose(1, 2, 0) - thresh
403
431
  bbox = bbox.reshape(-1, 2)
404
- _fast_disk_2d(segmentation, bbox, positions, identifiers, radius)
432
+ _fast_disk_2d(segmentation, bbox, positions, identifiers, radius, overwrite=overwrite)
405
433
 
406
434
 
407
- def save_tracks( # pylint: disable=too-many-branches,too-many-locals,too-many-statements
435
+ def save_tracks( # pylint: disable=too-many-branches,too-many-locals,too-many-statements,too-many-arguments
408
436
  path: Union[str, os.PathLike],
409
437
  tracks: Collection[byotrack.Track],
410
438
  detections_sequence: Sequence[byotrack.Detections] = (),
@@ -415,6 +443,8 @@ def save_tracks( # pylint: disable=too-many-branches,too-many-locals,too-many-s
415
443
  last=0,
416
444
  shape: Optional[Tuple[int, ...]] = None,
417
445
  n_digit=4,
446
+ anisotropy=1.0,
447
+ overwrite_detections=False,
418
448
  ):
419
449
  """Save tracks in the CTC format [10]
420
450
 
@@ -453,6 +483,11 @@ def save_tracks( # pylint: disable=too-many-branches,too-many-locals,too-many-s
453
483
  Default: None
454
484
  n_digit (int): Number of digit used to encode time in file names.
455
485
  Default: 4
486
+ anisotropy (float): Relative size of a pixel along the depth dimension
487
+ versus height/width dimensions.
488
+ Default: 1.0
489
+ overwrite_detections (bool): Overwrite the segmentation of objects with disk.
490
+ Default: False (Disk are only drawn on background)
456
491
 
457
492
  """
458
493
  path = pathlib.Path(path)
@@ -517,6 +552,8 @@ def save_tracks( # pylint: disable=too-many-branches,too-many-locals,too-many-s
517
552
  torch.stack(disk_positions).numpy(),
518
553
  np.array(disk_ids, dtype=np.uint16) + 1,
519
554
  np.full(len(disk_ids), default_radius, dtype=np.float32),
555
+ anisotropy=anisotropy,
556
+ overwrite=overwrite_detections,
520
557
  )
521
558
 
522
559
  # Safety checks because CTC is quite restrictive
@@ -10,16 +10,23 @@ class FijiRunner: # pylint: disable=too-few-public-methods
10
10
  Attributes:
11
11
  fiji_path (str | os.PathLike): Path to the fiji executable
12
12
  The executable can be found inside the installation folder of Fiji.
13
- Linux: Fiji.app/ImageJ-<os>.exe
13
+ Linux: Fiji.app/ImageJ-<os>
14
14
  Windows: Fiji.app/ImageJ-<os>.exe
15
15
  MacOs: Fiji.app/Contents/MacOs/ImageJ-<os>
16
+ capture_outputs (bool): Whether to PIPE stderr and stdout into Python
17
+ This will allow you to find the stdout/stderr inside `last_outputs`.
18
+ But outputs are captured, and you do not see them while the scripts run.
19
+ Default: False
20
+ last_outputs (subprocess.CompletedProcess): Outputs of the last subprocess.run
16
21
 
17
22
  """
18
23
 
19
24
  cmd = './{fiji} --ij2 --headless --console --run "{script}" "{args}"'
20
25
 
21
- def __init__(self, fiji_path: Union[str, os.PathLike]) -> None:
26
+ def __init__(self, fiji_path: Union[str, os.PathLike], capture_outputs=False) -> None:
22
27
  self.fiji_path = fiji_path
28
+ self.capture_outputs = capture_outputs
29
+ self.last_outputs: subprocess.CompletedProcess = subprocess.CompletedProcess("", 0)
23
30
 
24
31
  assert os.path.isfile(fiji_path), "Unable to found the given path"
25
32
 
@@ -42,4 +49,8 @@ class FijiRunner: # pylint: disable=too-few-public-methods
42
49
  cmd = cmd[2:] # Strip ./ on windows
43
50
 
44
51
  print("Calling Fiji with:", cmd)
45
- return subprocess.run(cmd, check=True, cwd=os.path.dirname(self.fiji_path), shell=True).returncode
52
+ self.last_outputs = subprocess.run(
53
+ cmd, check=True, cwd=os.path.dirname(self.fiji_path), shell=True, capture_output=self.capture_outputs
54
+ )
55
+
56
+ return self.last_outputs.returncode
@@ -268,7 +268,7 @@ class FrameByFrameLinkerParameters: # pylint: disable=too-many-instance-attribu
268
268
  self.split_factor = split_factor
269
269
  self.merge_factor = merge_factor
270
270
 
271
- if merge_factor >= 1.0 or split_factor >= 1.0:
271
+ if merge_factor > 1.0 or split_factor > 1.0:
272
272
  warnings.warn("Merge or split factors should be lower than 1")
273
273
 
274
274
  association_threshold: float = 5.0
@@ -280,7 +280,7 @@ class FrameByFrameLinkerParameters: # pylint: disable=too-many-instance-attribu
280
280
  merge_factor: float = 0.0
281
281
 
282
282
 
283
- class FrameByFrameLinker(byotrack.OnlineLinker):
283
+ class FrameByFrameLinker(byotrack.OnlineLinker): # pylint: disable=too-many-instance-attributes
284
284
  """Links detections online using frame-by-frame association
285
285
 
286
286
  Abstract class for frame-by-frame linker. It decomposes the update step in 6 parts:
@@ -338,8 +338,13 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
338
338
  self.inactive_tracks: List[TrackHandler] = []
339
339
  self.active_tracks: List[TrackHandler] = []
340
340
  self.all_positions: List[torch.Tensor] = []
341
- self.split_links = torch.zeros((0, 2), dtype=torch.int32)
342
- self.merge_links = torch.zeros((0, 2), dtype=torch.int32)
341
+ self.active_mass = torch.zeros((0,), dtype=torch.float32)
342
+
343
+ # Instantaneous useful quantities
344
+ self._links = torch.zeros((0, 2), dtype=torch.int32)
345
+ self._split_links = torch.zeros((0, 2), dtype=torch.int32)
346
+ self._merge_links = torch.zeros((0, 2), dtype=torch.int32)
347
+ self._unmatched_detections = torch.full((0,), True)
343
348
 
344
349
  def reset(self) -> None:
345
350
  super().reset()
@@ -349,8 +354,11 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
349
354
  self.inactive_tracks = []
350
355
  self.active_tracks = []
351
356
  self.all_positions = []
352
- self.split_links = torch.zeros((0, 2), dtype=torch.int32)
353
- self.merge_links = torch.zeros((0, 2), dtype=torch.int32)
357
+ self.active_mass = torch.zeros((0,), dtype=torch.float32)
358
+ self._links = torch.zeros((0, 2), dtype=torch.int32)
359
+ self._split_links = torch.zeros((0, 2), dtype=torch.int32)
360
+ self._merge_links = torch.zeros((0, 2), dtype=torch.int32)
361
+ self._unmatched_detections = torch.full((0,), True)
354
362
 
355
363
  def collect(self) -> List[byotrack.Track]:
356
364
  tracks = []
@@ -412,65 +420,52 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
412
420
  """
413
421
 
414
422
  @abstractmethod
415
- def post_association(self, frame: np.ndarray, detections: byotrack.Detections, links: torch.Tensor):
416
- """Update the tracks and the internal variables of the tracker
417
-
418
- It should call the `update` method of each active tracks and update any internal model/data.
419
- It should also create new track handlers for each extra detection.
420
- Finally, it is also responsible to register the position of each active track in `all_positions`
421
- for the current time frame.
423
+ def post_association(self, frame: np.ndarray, detections: byotrack.Detections, active_mask: torch.Tensor):
424
+ """Update the internal state of the tracker after `update_active_tracks`
422
425
 
423
- See `update_active_tracks` which can be called inside this implementation to handle tracks termination
424
- and creation.
426
+ It should update any internal model/data. It is also responsible to register the position of each active
427
+ track in `all_positions` for the current time frame.
425
428
 
426
429
  Args:
427
430
  frame (np.ndarray): The current frame of the video
428
431
  Shape: (H, W, C), dtype: float
429
432
  detections (byotrack.Detections): Detections for the given frame
430
- links (torch.Tensor): The links made between active tracks and the detections
431
- Shape: (L, 2), dtype: int32
433
+ active_mask (torch.Tensor): Boolean tensor indicating True for still active tracks
434
+ Shape: (N_tracks), dtype: bool
432
435
 
433
436
  """
434
437
 
435
438
  def update_active_tracks( # pylint: disable=too-many-branches,too-many-statements,too-many-locals
436
- self, links: torch.Tensor, detections: byotrack.Detections
437
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
439
+ self, detections: byotrack.Detections
440
+ ) -> torch.Tensor:
438
441
  """Updates tracks handler and creates new ones for extra detections
439
442
 
440
443
  Tracks that are terminated are stored inside `inactive_tracks` and dropped from `active_tracks`.
441
- It can be called inside `post_association` to facilitate the code.
444
+ It is called by update before `post_association`.
442
445
 
443
446
  It also handles merges and splits. In the case of some specific merges, it may change a few links.
444
447
  The updated links are returned, with a still_active mask for tracks and a new_track mask
445
448
  for detections.
446
449
 
447
450
  Args:
448
- links (torch.Tensor): The links made between active tracks and the detections
449
- Shape: (L, 2), dtype: int32
450
451
  detections (byotrack.Detections): Detections for the given frame
451
452
 
452
453
  Returns:
453
- torch.Tensor: Updated links (in case of some specific merges)
454
- Shape: (L, 2), dtype: int32
455
454
  torch.Tensor: Boolean tensor indicating True for still active tracks
456
455
  Shape: (N_tracks), dtype: bool
457
- torch.Tensor: Boolean tensor indicating True for newly created tracks from detections
458
- Shape: (N_dets), dtype: bool
459
456
 
460
457
  """
461
- if self.split_links.shape[0] + self.merge_links.shape[0] == 0: # Fall back to old simpler version
462
- return links, self._update_active_tracks(links), self._handle_extra_detections(detections, links)
458
+ if self._split_links.shape[0] + self._merge_links.shape[0] == 0: # Fall back to old simpler version
459
+ active_mask = self._update_active_tracks()
460
+ self._handle_extra_detections(detections)
461
+ return active_mask
463
462
 
464
463
  # Ugly to handle merge and splits smoothly... If you do not care about that, read only the old simpler version
465
464
  # Or even use offline split and merge strategies.
466
465
 
467
- # Find unmatched measures
468
- unmatched = torch.full((len(detections),), True)
469
- unmatched[links[:, 1]] = False
470
-
471
466
  # Create new tracks from unmatched measures
472
467
  new_tracks: List[TrackHandler] = []
473
- for i in torch.arange(len(detections))[unmatched].tolist():
468
+ for i in torch.arange(len(detections))[self._unmatched_detections].tolist():
474
469
  track = TrackHandler(
475
470
  self.specs.n_valid,
476
471
  self.specs.n_gap,
@@ -482,20 +477,22 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
482
477
 
483
478
  # Lots of useful identifiers mapping
484
479
  det_to_new_track = torch.full((len(detections),), -1, dtype=torch.int32)
485
- det_to_new_track[torch.arange(len(detections))[unmatched]] = torch.arange(len(new_tracks), dtype=torch.int32)
480
+ det_to_new_track[torch.arange(len(detections))[self._unmatched_detections]] = torch.arange(
481
+ len(new_tracks), dtype=torch.int32
482
+ )
486
483
 
487
484
  track_to_det = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
488
- track_to_det[links[:, 0]] = links[:, 1]
485
+ track_to_det[self._links[:, 0]] = self._links[:, 1]
489
486
  det_to_track = torch.full((len(detections),), -1, dtype=torch.int32)
490
- det_to_track[links[:, 1]] = links[:, 0]
487
+ det_to_track[self._links[:, 1]] = self._links[:, 0]
491
488
 
492
489
  track_to_det_split = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
493
- track_to_det_split[self.split_links[:, 0]] = self.split_links[:, 1]
490
+ track_to_det_split[self._split_links[:, 0]] = self._split_links[:, 1]
494
491
 
495
492
  track_to_det_merge = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
496
- track_to_det_merge[self.merge_links[:, 0]] = self.merge_links[:, 1]
493
+ track_to_det_merge[self._merge_links[:, 0]] = self._merge_links[:, 1]
497
494
  det_to_track_merge = torch.full((len(detections),), -1, dtype=torch.int32)
498
- det_to_track_merge[self.merge_links[:, 1]] = self.merge_links[:, 0]
495
+ det_to_track_merge[self._merge_links[:, 1]] = self._merge_links[:, 0]
499
496
 
500
497
  # Update active tracks (for merges and splits, they are still active but replaced by a new handler)
501
498
  active_mask = torch.full((len(self.active_tracks),), False)
@@ -619,30 +616,24 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
619
616
  track.merge_id = det_to_merge_id[track.merge_id]
620
617
 
621
618
  # Relabel links
622
- links[:, 0] = torch.arange(len(track_to_det))[track_to_det != -1]
623
- links[:, 1] = track_to_det[track_to_det != -1]
619
+ self._links[:, 0] = torch.arange(len(track_to_det))[track_to_det != -1]
620
+ self._links[:, 1] = track_to_det[track_to_det != -1]
624
621
 
625
622
  self.active_tracks = still_active + new_tracks
626
623
 
627
- return links, active_mask, unmatched
624
+ return active_mask
628
625
 
629
- def _update_active_tracks(self, links: torch.Tensor) -> torch.Tensor:
626
+ def _update_active_tracks(self) -> torch.Tensor:
630
627
  """Calls `update` for active tracks and return a boolean mask that indicates which track is still active
631
628
 
632
629
  Tracks that are terminated are stored inside `inactive_tracks` and dropped from `active_tracks`.
633
630
 
634
- It can be called inside `post_association` to simplify the code.
635
-
636
- Args:
637
- links (torch.Tensor): The links made between active tracks and the detections
638
- Shape: (L, 2), dtype: int32
639
-
640
631
  Returns:
641
632
  torch.Tensor: Boolean tensor indicating True for still active tracks
642
633
 
643
634
  """
644
635
  i_to_j = torch.full((len(self.active_tracks),), -1, dtype=torch.int32)
645
- i_to_j[links[:, 0]] = links[:, 1]
636
+ i_to_j[self._links[:, 0]] = self._links[:, 1]
646
637
  active_mask = torch.full((len(self.active_tracks),), False)
647
638
  still_active = []
648
639
  for i, track in enumerate(self.active_tracks):
@@ -659,36 +650,21 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
659
650
 
660
651
  return active_mask
661
652
 
662
- def _handle_extra_detections(self, detections: byotrack.Detections, links: torch.Tensor) -> torch.Tensor:
653
+ def _handle_extra_detections(self, detections: byotrack.Detections):
663
654
  """Handle extra detections by creating new track handlers
664
655
 
665
- It can be called inside `post_association` to create track handlers from extra detections. It
666
- returns a boolean mask indicating a track creating for each detection.
667
-
668
656
  Args:
669
657
  detections (byotrack.Detections): Detections for the given frame
670
- links (torch.Tensor): The links made between active tracks and the detections
671
- Shape: (L, 2), dtype: int32
672
-
673
- Returns:
674
- torch.Tensor: Boolean tensor indicating True for newly created tracks from detections
675
- Shape: (N_dets), dtype: bool
676
658
 
677
659
  """
678
- # Find unmatched measures
679
- unmatched = torch.full((len(detections),), True)
680
- unmatched[links[:, 1]] = False
681
-
682
660
  # Create a new active track for each unmatched measure
683
- for i in torch.arange(len(detections))[unmatched].tolist():
661
+ for i in torch.arange(len(detections))[self._unmatched_detections].tolist():
684
662
  handler = TrackHandler(
685
663
  self.specs.n_valid, self.specs.n_gap, self.frame_id, len(self.inactive_tracks) + len(self.active_tracks)
686
664
  )
687
665
  handler.update(self.frame_id, i)
688
666
  self.active_tracks.append(handler)
689
667
 
690
- return unmatched
691
-
692
668
  def update_detections(self, detections: byotrack.Detections) -> byotrack.Detections:
693
669
  """Optional modification of the currrent detections based on the current state
694
670
 
@@ -720,15 +696,16 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
720
696
  """
721
697
 
722
698
  cost, threshold = self.cost(frame, detections)
723
- links = self.specs.association_method.solve(cost, threshold)
699
+ self._links = self.specs.association_method.solve(cost, threshold)
700
+
701
+ self._unmatched_detections = torch.full((len(detections),), True)
702
+ self._unmatched_detections[self._links[:, 1]] = False
724
703
 
725
704
  if self.specs.merge_factor == 0 and self.specs.split_factor == 0:
726
- return links # No merge or splits
705
+ return self._links # No merge or splits
727
706
 
728
- unmatched_detections = torch.full((len(detections),), True)
729
- unmatched_detections[links[:, 1]] = False
730
707
  unmatched_tracks = torch.full((len(self.active_tracks),), True)
731
- unmatched_tracks[links[:, 0]] = False
708
+ unmatched_tracks[self._links[:, 0]] = False
732
709
  valid_tracks = torch.tensor(
733
710
  [(track.track_state == TrackHandler.TrackState.VALID) for track in self.active_tracks], dtype=torch.bool
734
711
  )
@@ -736,25 +713,45 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
736
713
  if self.specs.merge_factor > 0:
737
714
  # We simply do a 2nd association between unassociated VALID tracks with associated detections
738
715
  tracks_mask = unmatched_tracks & valid_tracks
739
- self.merge_links = self.specs.association_method.solve(
740
- cost[tracks_mask][:, ~unmatched_detections], threshold * self.specs.merge_factor
716
+
717
+ # TODO: Mass factor
718
+ # self.active_mass[tracks_mask]
719
+
720
+ self._merge_links = self.specs.association_method.solve(
721
+ cost[tracks_mask][:, ~self._unmatched_detections], threshold * self.specs.merge_factor
741
722
  )
742
723
 
743
724
  # Relabel
744
- self.merge_links[:, 0] = torch.arange(len(self.active_tracks))[tracks_mask][self.merge_links[:, 0]]
745
- self.merge_links[:, 1] = torch.arange(len(detections))[~unmatched_detections][self.merge_links[:, 1]]
725
+ self._merge_links[:, 0] = torch.arange(len(self.active_tracks))[tracks_mask][self._merge_links[:, 0]]
726
+ self._merge_links[:, 1] = torch.arange(len(detections))[~self._unmatched_detections][
727
+ self._merge_links[:, 1]
728
+ ]
746
729
 
747
730
  if self.specs.split_factor > 0:
748
731
  # We simply do a 2nd association between associated tracks with unassociated detections
749
- self.split_links = self.specs.association_method.solve(
750
- cost[~unmatched_tracks][:, unmatched_detections], threshold * self.specs.split_factor
732
+
733
+ # Split mass factor
734
+ # We increase the distance if the split is not evenly weighted and if it does not sum at the previous mass
735
+ track_mass = self.active_mass[self._links[:, 0]]
736
+ associated_mass = detections.mass[self._links[:, 1]]
737
+ non_associated_mass = detections.mass[self._unmatched_detections]
738
+
739
+ even_factor = torch.maximum(associated_mass[:, None], non_associated_mass[None, :]) / torch.minimum(
740
+ associated_mass[:, None], non_associated_mass[None, :]
741
+ )
742
+ sum_ = associated_mass[:, None] + non_associated_mass[None, :]
743
+ mass_factor = torch.maximum(track_mass[:, None], sum_) / torch.minimum(track_mass[:, None], sum_)
744
+
745
+ self._split_links = self.specs.association_method.solve(
746
+ cost[~unmatched_tracks][:, self._unmatched_detections] * even_factor * mass_factor,
747
+ threshold * self.specs.split_factor,
751
748
  )
752
749
 
753
750
  # Relabel
754
- self.split_links[:, 0] = torch.arange(len(self.active_tracks))[~unmatched_tracks][self.split_links[:, 0]]
755
- self.split_links[:, 1] = torch.arange(len(detections))[unmatched_detections][self.split_links[:, 1]]
751
+ self._split_links[:, 0] = torch.arange(len(self.active_tracks))[~unmatched_tracks][self._split_links[:, 0]]
752
+ self._split_links[:, 1] = torch.arange(len(detections))[self._unmatched_detections][self._split_links[:, 1]]
756
753
 
757
- return links
754
+ return self._links
758
755
 
759
756
  def update(self, frame: np.ndarray, detections: byotrack.Detections) -> None:
760
757
  self.frame_id += 1
@@ -776,8 +773,15 @@ class FrameByFrameLinker(byotrack.OnlineLinker):
776
773
  remove_feats = True
777
774
  self.features_extractor.register(frame, detections)
778
775
 
779
- links = self.associate(frame, detections)
780
- self.post_association(frame, detections, links)
776
+ self.associate(frame, detections)
777
+ active_mask = self.update_active_tracks(detections)
778
+ self.post_association(frame, detections, active_mask)
779
+
780
+ # Handle mass with a fixed ema and concatenate with mass of newly created track
781
+ self.active_mass[self._links[:, 0]] -= (1.0 - 0.8) * (
782
+ self.active_mass[self._links[:, 0]] - detections.mass[self._links[:, 1]]
783
+ )
784
+ self.active_mass = torch.cat((self.active_mass[active_mask], detections.mass[self._unmatched_detections]))
781
785
 
782
786
  assert len(self.all_positions[-1]) == len(
783
787
  self.active_tracks
@@ -342,22 +342,19 @@ class KalmanLinker(FrameByFrameLinker):
342
342
  cost = -self.projections[:, None].log_likelihood(detections.position[None, ..., None])
343
343
  return cost, -torch.log(torch.tensor(self.specs.association_threshold)).item()
344
344
 
345
- def post_association(self, _: np.ndarray, detections: byotrack.Detections, links: torch.Tensor):
345
+ def post_association(self, _: np.ndarray, detections: byotrack.Detections, active_mask: torch.Tensor):
346
346
  if self.active_states is None or self.kalman_filter is None or self.projections is None:
347
347
  raise RuntimeError("The linker should already be initialized.")
348
348
 
349
- # Update handlers
350
- links, active_mask, unmatched = self.update_active_tracks(links, detections)
351
-
352
349
  # Update the state of associated tracks (unassociated tracks keep the predicted state)
353
- self.active_states[links[:, 0]] = self.kalman_filter.update(
354
- self.active_states[links[:, 0]],
355
- detections.position[links[:, 1]][..., None],
356
- projection=self.projections[links[:, 0]],
350
+ self.active_states[self._links[:, 0]] = self.kalman_filter.update(
351
+ self.active_states[self._links[:, 0]],
352
+ detections.position[self._links[:, 1]][..., None],
353
+ projection=self.projections[self._links[:, 0]],
357
354
  )
358
355
 
359
356
  # Create new states for unmatched measures
360
- unmatched_measures = detections.position[unmatched]
357
+ unmatched_measures = detections.position[self._unmatched_detections]
361
358
 
362
359
  # Build the initial states for tracks:
363
360
  # We initialize the position using the detection position and the measurement std as covariance.
@@ -38,7 +38,7 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
38
38
  order-th derivative can change between two consecutive frames. A common rule of thumb is to use
39
39
  3 * process_std ~= max_t(| x^(order)(t) - x^(order)(t+1)|). It can be provided for each dimension).
40
40
  Default: 1.5 pixels / frame^order
41
- kalman_order (int): Order of the Kalman filter to use. 0 is not supported.
41
+ kalman_order (int): Order of the Kalman filter to use. 0 is for brownian motion (it predicts a 0 velocity)
42
42
  1 for directed brownian motion, 2 for accelerated brownian motions, etc...
43
43
  Default: 1
44
44
  n_valid (int): Number associated detections required to validate the track after its creation.
@@ -118,8 +118,6 @@ class KOFTLinkerParameters(KalmanLinkerParameters):
118
118
  self.extract_flows_on_detections = extract_flows_on_detections
119
119
  self.always_measure_velocity = always_measure_velocity
120
120
 
121
- assert self.kalman_order >= 1, "With KOFT, the velocity is measured and thus should be modeled."
122
-
123
121
  flow_std: Union[float, torch.Tensor] = 1.0
124
122
  extract_flows_on_detections: bool = False
125
123
  always_measure_velocity: bool = True
@@ -234,9 +232,13 @@ class KOFTLinker(KalmanLinker):
234
232
  self.specs.detection_std,
235
233
  self.specs.process_std,
236
234
  dim=detections.dim,
237
- order=self.specs.kalman_order,
235
+ order=self.specs.kalman_order + (self.specs.kalman_order == 0),
238
236
  approximate=True, # Approximate so that a flow precisely means the velocity modeled here
239
237
  )
238
+ if self.specs.kalman_order == 0:
239
+ # In order 0, we still model velocity, but we always predict it at 0
240
+ self.kalman_filter.process_matrix[detections.dim :, detections.dim :] = 0
241
+
240
242
  # Doubles the measurement space to measure velocity
241
243
  self.kalman_filter.measurement_matrix = torch.eye(detections.dim * 2, self.kalman_filter.state_dim)
242
244
  self.kalman_filter.measurement_noise = torch.eye(detections.dim * 2)
@@ -297,22 +299,19 @@ class KOFTLinker(KalmanLinker):
297
299
  cost = -projections[:, None].log_likelihood(detections.position[None, ..., None])
298
300
  return cost, -torch.log(torch.tensor(self.specs.association_threshold)).item()
299
301
 
300
- def post_association(self, _: np.ndarray, detections: byotrack.Detections, links: torch.Tensor):
302
+ def post_association(self, _: np.ndarray, detections: byotrack.Detections, active_mask: torch.Tensor):
301
303
  if self.active_states is None or self.kalman_filter is None or self.projections is None:
302
304
  raise RuntimeError("The linker should already be initialized.")
303
305
 
304
306
  self.last_detections = detections # Save detections (May be required)
305
307
 
306
- # Update handlers
307
- links, active_mask, unmatched = self.update_active_tracks(links, detections)
308
-
309
308
  # Update the state of associated tracks (unassociated tracks keep the predicted state)
310
- self.active_states[links[:, 0]] = self.kalman_filter.update(
311
- self.active_states[links[:, 0]],
312
- detections.position[links[:, 1]][..., None],
309
+ self.active_states[self._links[:, 0]] = self.kalman_filter.update(
310
+ self.active_states[self._links[:, 0]],
311
+ detections.position[self._links[:, 1]][..., None],
313
312
  projection=torch_kf.GaussianState(
314
- self.projections.mean[links[:, 0], : detections.dim],
315
- self.projections.covariance[links[:, 0], : detections.dim, : detections.dim],
313
+ self.projections.mean[self._links[:, 0], : detections.dim],
314
+ self.projections.covariance[self._links[:, 0], : detections.dim, : detections.dim],
316
315
  None, # /!\ inv(cov[:2,:2]) != inv(cov)[:2, :2] =>
317
316
  ),
318
317
  measurement_matrix=self.kalman_filter.measurement_matrix[: detections.dim],
@@ -320,7 +319,7 @@ class KOFTLinker(KalmanLinker):
320
319
  )
321
320
 
322
321
  # Create new states for unmatched measures
323
- unmatched_measures = detections.position[unmatched]
322
+ unmatched_measures = detections.position[self._unmatched_detections]
324
323
  self.n_initial = unmatched_measures.shape[0]
325
324
 
326
325
  # Build the initial states for tracks:
@@ -135,21 +135,20 @@ class NearestNeighborLinker(FrameByFrameLinker):
135
135
  self.specs.association_threshold,
136
136
  )
137
137
 
138
- def post_association(self, _: np.ndarray, detections: byotrack.Detections, links: torch.Tensor):
138
+ def post_association(self, _: np.ndarray, detections: byotrack.Detections, active_mask: torch.Tensor):
139
139
  if self.active_positions is None:
140
140
  self.active_positions = torch.empty((0, detections.position.shape[1]))
141
141
 
142
- # Update handlers
143
- links, active_mask, unmatched = self.update_active_tracks(links, detections)
144
-
145
142
  # Update tracks positions with detections
146
143
  # Optionally using an EMA to reduce detections noise
147
- self.active_positions[links[:, 0]] -= (1.0 - self.specs.ema) * (
148
- self.active_positions[links[:, 0]] - detections.position[links[:, 1]]
144
+ self.active_positions[self._links[:, 0]] -= (1.0 - self.specs.ema) * (
145
+ self.active_positions[self._links[:, 0]] - detections.position[self._links[:, 1]]
149
146
  )
150
147
 
151
148
  # Merge still active positions and new ones
152
- self.active_positions = torch.cat((self.active_positions[active_mask], detections.position[unmatched]))
149
+ self.active_positions = torch.cat(
150
+ (self.active_positions[active_mask], detections.position[self._unmatched_detections])
151
+ )
153
152
 
154
153
  self.all_positions.append(self.active_positions.clone())
155
154
 
@@ -123,7 +123,6 @@ class TrackMateLinker(byotrack.Linker): # pylint: disable=too-few-public-method
123
123
 
124
124
  Note:
125
125
  This implementation requires Fiji to be installed (https://imagej.net/downloads)
126
- And tifffile library (https://github.com/cgohlke/tifffile#quickstart)
127
126
 
128
127
  Note:
129
128
  In case of missed detections, positions are filled with nan. To fill nan with true values, use an Interpolator
@@ -144,7 +143,7 @@ class TrackMateLinker(byotrack.Linker): # pylint: disable=too-few-public-method
144
143
  Args:
145
144
  fiji_path (str | os.PathLike): Path to the fiji executable
146
145
  The executable can be found inside the installation folder of Fiji.
147
- Linux: Fiji.app/ImageJ-<os>.exe
146
+ Linux: Fiji.app/ImageJ-<os>
148
147
  Windows: Fiji.app/ImageJ-<os>.exe
149
148
  MacOs: Fiji.app/Contents/MacOs/ImageJ-<os>
150
149
 
@@ -8,6 +8,7 @@ In the future, we may wrap the Fiji plugin to access BIO metrics.
8
8
  import os
9
9
  import pathlib
10
10
  import platform
11
+ import re
11
12
  import shutil
12
13
  import subprocess
13
14
  import tempfile
@@ -16,6 +17,7 @@ import warnings
16
17
 
17
18
 
18
19
  import byotrack
20
+ from byotrack import fiji
19
21
  from byotrack.dataset import ctc
20
22
  from byotrack.utils import sorted_alphanumeric
21
23
 
@@ -39,6 +41,7 @@ class CTCSoftwareRunner: # pylint: disable=too-few-public-methods
39
41
  ctc_software (pathlib.Path): Path to the ctc software folder.
40
42
  It should be the root folder containing Win/Linux/Mac subfolders with their executables.
41
43
  system (str): The user system in the CTC nomenclature. One of Linux, Mac or Win.
44
+ last_log (str): Logs of the last computed metrics
42
45
 
43
46
  """
44
47
 
@@ -46,6 +49,7 @@ class CTCSoftwareRunner: # pylint: disable=too-few-public-methods
46
49
 
47
50
  def __init__(self, ctc_software: Union[str, os.PathLike]):
48
51
  self.ctc_software = pathlib.Path(ctc_software)
52
+ self.last_log = ""
49
53
 
50
54
  self.system = platform.system()
51
55
  if self.system == "Darwin":
@@ -99,6 +103,8 @@ class CTCSoftwareRunner: # pylint: disable=too-few-public-methods
99
103
  "Cannot parse outputs, the CTC software probably found an error: " + output_string
100
104
  ) from exc
101
105
 
106
+ self.last_log = (pathlib.Path(dataset) / f"{seq:02}_RES" / f"{metric}_log.txt").read_text("utf-8")
107
+
102
108
  return value
103
109
 
104
110
 
@@ -129,10 +135,6 @@ class CTCMetrics(CTCSoftwareRunner):
129
135
 
130
136
  """
131
137
 
132
- def __init__(self, ctc_software):
133
- super().__init__(ctc_software)
134
- self.last_log = ""
135
-
136
138
  def compute_tracking_metric(
137
139
  self,
138
140
  metric: str,
@@ -191,8 +193,6 @@ class CTCMetrics(CTCSoftwareRunner):
191
193
  else:
192
194
  raise ValueError("Without any detections, `shape` must be provided")
193
195
 
194
- if detections_sequence:
195
- shape = detections_sequence[0].shape
196
196
  with tempfile.TemporaryDirectory(prefix="ByoTrack-CTC-Metrics") as output_dir:
197
197
  output_path = pathlib.Path(output_dir)
198
198
  ground_truth_path = output_path / "01_GT" / ("SEG" if metric == "SEG" else "TRA")
@@ -216,10 +216,7 @@ class CTCMetrics(CTCSoftwareRunner):
216
216
  **kwargs,
217
217
  )
218
218
 
219
- results = self.run(metric, output_path, 1)
220
- self.last_log = (output_path / "01_RES" / f"{metric}_log.txt").read_text("utf-8")
221
-
222
- return results
219
+ return self.run(metric, output_path, 1)
223
220
 
224
221
  def compute_detection_metric(
225
222
  self,
@@ -268,10 +265,7 @@ class CTCMetrics(CTCSoftwareRunner):
268
265
  ground_truth_path, ground_truth_detections_sequence, as_res=False, as_seg=metric == "SEG"
269
266
  )
270
267
 
271
- results = self.run(metric, output_path, 1)
272
- self.last_log = (output_path / "01_RES" / f"{metric}_log.txt").read_text("utf-8")
273
-
274
- return results
268
+ return self.run(metric, output_path, 1)
275
269
 
276
270
  @staticmethod
277
271
  def copy_ground_truth(
@@ -351,3 +345,147 @@ class CTCMetrics(CTCSoftwareRunner):
351
345
  for path in res_tiff_paths if is_res else gt_track_tiff_paths:
352
346
  frame_id = int(path.stem[4 if is_res else 9 :])
353
347
  shutil.copy(path, target_path / f"man_{'seg' if as_seg else 'track'}{frame_id:04}.tif")
348
+
349
+
350
+ class BioMetrics:
351
+ """Wrapper around the CTC "Biological measures" Fiji plugin [10]
352
+
353
+ It allows the computations of BIO metrics for the Cell Tracking Challenge [10].
354
+
355
+ Note:
356
+ This implementation requires Fiji to be installed (https://imagej.net/downloads)
357
+ with the CTC plugins (https://github.com/CellTrackingChallenge/fiji-plugins)
358
+
359
+ Attributes:
360
+ runner (byotrack.fiji.FijiRunner): Fiji runner
361
+ last_metrics (Tuple[float, float, float, float]): Sub metrics information for the last computed BIO
362
+ It consists of ("CT", "TF", "BC(i)", "CCA"). Their average gives the BIO metric.
363
+
364
+ """
365
+
366
+ plugin_name = "Biological measures"
367
+ # Reg exp for parsing outputs: find the last 4 numbers in the output line such in the format 'a, b, c, d]]'
368
+ output_regexp = re.compile(r"([+-]?\d*\.\d+), ([+-]?\d*\.\d+), ([+-]?\d*\.\d+), ([+-]?\d*\.\d+)]]$")
369
+
370
+ def __init__(self, fiji_path: Union[str, os.PathLike]) -> None:
371
+ """Constructor
372
+
373
+ Args:
374
+ fiji_path (str | os.PathLike): Path to the fiji executable
375
+ The executable can be found inside the installation folder of Fiji.
376
+ Linux: Fiji.app/ImageJ-<os>
377
+ Windows: Fiji.app/ImageJ-<os>.exe
378
+ MacOs: Fiji.app/Contents/MacOs/ImageJ-<os>
379
+
380
+ """
381
+ self.runner = fiji.FijiRunner(fiji_path, capture_outputs=True)
382
+ self.last_metrics = (0.0, 0.0, 0.0, 0.0)
383
+
384
+ def run(self, dataset: Union[str, os.PathLike], seq=1, n_digit=4) -> float:
385
+ """Run the CTC "Biological measures" plugin on the given dataset.
386
+
387
+ The dataset should already have results stored in it. It expects the CTC format.
388
+
389
+ Args:
390
+ metric (str): The metric to evaluate. One of (TRA, DET, SEG).
391
+ dataset (Union[str, os.PathLike]): Path to the dataset to evaluate.
392
+ seq (int): Sequence to evaluate inside the dataset.
393
+ The plugin will compare {dataset}/{seq:02}_RES with {dataset}/{seq:02}_GT
394
+ Default: 1
395
+ n_digit (int): Number of digits used to encode time in file names.
396
+ It is dataset dependant, but in ByoTrack, by default we use 4 digits.
397
+ Default: 4
398
+
399
+ Returns:
400
+ float: The evaluated metric
401
+
402
+ """
403
+ path = pathlib.Path(dataset)
404
+
405
+ # Let's do output redirection
406
+ self.runner.run(
407
+ "Biological measures", resPath=path / f"{seq:02}_RES", gtPath=path / f"{seq:02}_GT", noOfDigits=n_digit
408
+ )
409
+
410
+ line = self.runner.last_outputs.stdout.decode().strip().split("\n")[-1]
411
+ match = self.output_regexp.search(line)
412
+ if not match:
413
+ raise RuntimeError("Cannot parse outputs, the CTC software probably found an error: " + line)
414
+
415
+ self.last_metrics = tuple( # type: ignore
416
+ (float(group) if float(group) >= 0.0 else 0.0) for group in match.groups()
417
+ )
418
+
419
+ return sum(self.last_metrics) / 4
420
+
421
+ def compute(
422
+ self,
423
+ tracks: Collection[byotrack.Track],
424
+ ground_truth_tracks: Union[str, os.PathLike, Collection[byotrack.Track]],
425
+ *,
426
+ detections_sequence: Sequence[byotrack.Detections] = (),
427
+ ground_truth_detections_sequence: Sequence[byotrack.Detections] = (),
428
+ **kwargs,
429
+ ) -> float:
430
+ """Compute BIO metric for the given tracks and ground truthes.
431
+
432
+ It will create a temporary folder, to store the tracks and ground-truthes in the right format and then
433
+ execute the CTC plugins. The temporary folder is removed at the end.
434
+
435
+ See `run` to simply compute the metric of an already existing folder.
436
+
437
+ Note:
438
+ In CTC, matching of predicted tracks with GT ones is done based on a kind of IOU.
439
+ Therefore, it may be useful to provide the detections_sequence associated with the tracks.
440
+ You may also tweak the `default_radius` arguments (in kwargs). (See `ctc.save_tracks`)
441
+
442
+ Args:
443
+ tracks (Collection[byotrack.Track]): Predicted tracks to evaluate.
444
+ ground_truth_tracks (Union[str, os.PathLike, Collection[byotrack.Track]]): Ground truth data.
445
+ It is either a path to the GT tracks folder, which will be copied in our temporary folder.
446
+ Or it is a list of ByoTrack.Track, that will be saved in the temporary folder.
447
+ detections_sequence (Sequence[byotrack.Detections]): Optional detections, used when saving the tracks.
448
+ Default: () # No detections and tracks segmentations will be disk of radius `default_radius`.
449
+ ground_truth_detections_sequence (Sequence[byotrack.Detections]): Optional detections for ground-truth.
450
+ When saving the GT tracks, these detections are used. See `ctc.save_tracks`.
451
+ Default: () # No detections and tracks segmentations will be disk of radius `default_radius`.
452
+ **kwargs: Additional arguments to provide to the `ctc.save_tracks` function.
453
+ 'shape': Provide the shape of saved image. It is mandatory if no detections is provided.
454
+ 'default_radius': Radius of the disk drawn for tracks that have no detections.
455
+ 'last': Last frame to consider. (Typically to shorten the sequences,
456
+ or if no object is tracked on the last frames, this will enforce the creation of empty tiff files)
457
+
458
+ """
459
+ if "shape" in kwargs:
460
+ shape = kwargs.pop("shape")
461
+ else:
462
+ if detections_sequence:
463
+ shape = detections_sequence[0].shape
464
+ elif ground_truth_detections_sequence:
465
+ shape = ground_truth_detections_sequence[0].shape
466
+ else:
467
+ raise ValueError("Without any detections, `shape` must be provided")
468
+
469
+ with tempfile.TemporaryDirectory(prefix="ByoTrack-CTC-Metrics") as output_dir:
470
+ output_path = pathlib.Path(output_dir)
471
+ ground_truth_path = output_path / "01_GT" / "TRA"
472
+ ctc.save_tracks(
473
+ output_path / "01_RES", tracks, detections_sequence, as_res=True, shape=shape, n_digit=4, **kwargs
474
+ )
475
+ if isinstance(ground_truth_tracks, (str, os.PathLike)):
476
+ if ground_truth_detections_sequence:
477
+ warnings.warn(
478
+ "When using a saved GT folder, it will be copied and ground-truth detections are ignored"
479
+ )
480
+ CTCMetrics.copy_ground_truth(ground_truth_tracks, ground_truth_path)
481
+ else:
482
+ ctc.save_tracks(
483
+ ground_truth_path,
484
+ ground_truth_tracks,
485
+ ground_truth_detections_sequence,
486
+ as_res=False,
487
+ shape=shape,
488
+ **kwargs,
489
+ )
490
+
491
+ return self.run(output_path, 1)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: byotrack
3
- Version: 1.2.0.dev2
3
+ Version: 1.2.0.dev3
4
4
  Summary: Biological particle tracking with Python
5
5
  Home-page: https://github.com/raphaelreme/byotrack
6
6
  Author: Raphael Reme
File without changes
File without changes
File without changes
File without changes