singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
singlebehaviorlab/gui/motion_tracking.py
@@ -0,0 +1,764 @@
+ """
+ Motion-Aware Tracking Enhancements for SAM2.
+
+ Includes:
+ - Kalman filter-based motion prediction (SAMURAI-style)
+ - OC-SORT drift correction (virtual trajectory + ORU)
+ - Mask quality scoring and temporal consistency checks
+ """
+
+ import numpy as np
+ import torch
+ import logging
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class KalmanBoxTracker:
+     """
+     Kalman filter for tracking bounding boxes with OC-SORT enhancements.
+     State: [x_center, y_center, scale, aspect_ratio, dx, dy, d_scale]
+
+     OC-SORT features:
+     - Virtual trajectory: During occlusion, generate virtual observations from velocity
+     - ORU (Observation-Centric Re-Update): When object reappears, correct past states
+     """
+     def __init__(self, bbox, delta_t=1, inertia=0.2):
+         """
+         Initialize tracker with bounding box [x1, y1, x2, y2].
+
+         Args:
+             bbox: Initial bounding box [x1, y1, x2, y2]
+             delta_t: Frame interval for virtual trajectory (OC-SORT paper: 1)
+             inertia: Velocity smoothing factor for ORU (OC-SORT paper: 0.2)
+         """
+         from filterpy.kalman import KalmanFilter
+
+         self.kf = KalmanFilter(dim_x=7, dim_z=4)
+         self.delta_t = delta_t
+         self.inertia = inertia
+
+         # State transition matrix (constant velocity model)
+         self.kf.F = np.array([
+             [1, 0, 0, 0, 1, 0, 0],
+             [0, 1, 0, 0, 0, 1, 0],
+             [0, 0, 1, 0, 0, 0, 1],
+             [0, 0, 0, 1, 0, 0, 0],
+             [0, 0, 0, 0, 1, 0, 0],
+             [0, 0, 0, 0, 0, 1, 0],
+             [0, 0, 0, 0, 0, 0, 1]
+         ], dtype=np.float32)
+
+         # Measurement matrix
+         self.kf.H = np.array([
+             [1, 0, 0, 0, 0, 0, 0],
+             [0, 1, 0, 0, 0, 0, 0],
+             [0, 0, 1, 0, 0, 0, 0],
+             [0, 0, 0, 1, 0, 0, 0]
+         ], dtype=np.float32)
+
+         # Measurement noise covariance
+         self.kf.R *= 10.0
+         self.kf.R[2, 2] *= 10.0  # Scale measurement has higher noise
+
+         # Initial state covariance (P)
+         self.kf.P[4:, 4:] *= 1000.0  # High uncertainty in velocities initially
+         self.kf.P *= 10.0
+
+         # Process noise covariance (Q): keep velocity states slow to change
+         self.kf.Q[-1, -1] *= 0.01
+         self.kf.Q[4:, 4:] *= 0.01
+
+         self.kf.x[:4] = self._bbox_to_z(bbox).reshape(4, 1)
+
+         self.time_since_update = 0
+         self.history = []
+         self.hits = 0
+         self.age = 0
+
+         # OC-SORT: Store observations for ORU
+         self.observations = {}  # frame_idx -> observation (z)
+         self.last_observation_frame = 0
+         self.last_observation = self._bbox_to_z(bbox)
+         self.frozen_velocity = None  # Velocity frozen at occlusion start
+
+     def _bbox_to_z(self, bbox):
+         """Convert [x1, y1, x2, y2] to [x_center, y_center, scale, aspect_ratio]."""
+         x1, y1, x2, y2 = bbox
+         w = x2 - x1
+         h = y2 - y1
+         x_c = x1 + w / 2.0
+         y_c = y1 + h / 2.0
+         s = w * h  # scale (area)
+         r = w / (h + 1e-6)  # aspect ratio
+         return np.array([x_c, y_c, s, r], dtype=np.float32)
+
+     def _z_to_bbox(self, z):
+         """Convert [x_center, y_center, scale, aspect_ratio] to [x1, y1, x2, y2]."""
+         x_c, y_c, s, r = z.flatten()[:4]
+         w = np.sqrt(max(s * r, 1.0))
+         h = s / (w + 1e-6)
+         return np.array([x_c - w/2, y_c - h/2, x_c + w/2, y_c + h/2], dtype=np.float32)
+
+     def predict(self):
+         """Predict next state and return predicted bbox."""
+         # Guard: keep the predicted scale from going non-positive
+         if self.kf.x[6] + self.kf.x[2] <= 0:
+             self.kf.x[6] = 0.0
+         self.kf.predict()
+         self.age += 1
+         self.time_since_update += 1
+         self.history.append(self._z_to_bbox(self.kf.x))
+         return self.history[-1]
+
+     def predict_with_virtual_trajectory(self, frame_idx):
+         """
+         OC-SORT: Predict with virtual trajectory during occlusion.
+         Uses last known velocity to generate virtual observations.
+
+         Args:
+             frame_idx: Current frame index
+
+         Returns:
+             Predicted bbox
+         """
+         if self.time_since_update == 0:
+             # Not occluded, use normal prediction
+             return self.predict()
+
+         # Freeze velocity at start of occlusion
+         if self.frozen_velocity is None:
+             self.frozen_velocity = self.kf.x[4:7].copy()
+
+         # Generate virtual observation using frozen velocity
+         virtual_z = self.last_observation.copy()
+         dt = frame_idx - self.last_observation_frame
+         if dt > 0:
+             virtual_z[0] += self.frozen_velocity[0, 0] * dt  # x_center
+             virtual_z[1] += self.frozen_velocity[1, 0] * dt  # y_center
+             virtual_z[2] += self.frozen_velocity[2, 0] * dt  # scale
+             virtual_z[2] = max(virtual_z[2], 1.0)  # scale must stay positive
+
+         # Fold the virtual observation into the Kalman state so predicted and
+         # observed trajectories stay in sync during occlusions.
+         self.kf.predict()
+         self.kf.update(virtual_z)
+         self.age += 1
+         self.time_since_update += 1
+
+         bbox = self._z_to_bbox(self.kf.x)
+         self.history.append(bbox)
+         return bbox
+
+     def update(self, bbox, frame_idx=None):
+         """Update state with observed bbox."""
+         z = self._bbox_to_z(bbox)
+
+         # OC-SORT: Apply ORU if recovering from occlusion
+         if self.time_since_update > 0 and frame_idx is not None:
+             self._apply_oru(z, frame_idx)
+
+         self.time_since_update = 0
+         self.history = []
+         self.hits += 1
+         self.kf.update(z)
+
+         # Store observation for future ORU
+         if frame_idx is not None:
+             self.observations[frame_idx] = z.copy()
+             self.last_observation_frame = frame_idx
+         self.last_observation = z.copy()
+         self.frozen_velocity = None  # Reset frozen velocity
+
+     def _apply_oru(self, new_z, frame_idx):
+         """
+         OC-SORT: Observation-Centric Re-Update.
+         When object reappears after occlusion, interpolate backwards
+         to correct past Kalman state drift.
+
+         Args:
+             new_z: New observation [x_c, y_c, s, r]
+             frame_idx: Current frame index
+         """
+         if self.time_since_update <= 1:
+             return  # ORU only matters after more than one missed frame
+
+         dt = frame_idx - self.last_observation_frame
+         if dt <= 0:
+             return
+
+         velocity = (new_z - self.last_observation) / dt
+
+         # Apply velocity smoothing with inertia:
+         # mix the new velocity with the old frozen velocity.
+         if self.frozen_velocity is not None:
+             old_vel = self.frozen_velocity.flatten()[:3]
+             new_vel = velocity[:3]
+             smoothed_vel = self.inertia * old_vel + (1 - self.inertia) * new_vel
+         else:
+             smoothed_vel = velocity[:3]
+
+         # Re-update state with corrected velocity
+         self.kf.x[4, 0] = smoothed_vel[0]  # dx
+         self.kf.x[5, 0] = smoothed_vel[1]  # dy
+         self.kf.x[6, 0] = smoothed_vel[2]  # d_scale
+
+         # Reduce uncertainty since we now have a good observation
+         self.kf.P[4:, 4:] *= 0.5
+
+     def get_state(self):
+         """Get current bbox estimate."""
+         return self._z_to_bbox(self.kf.x)
+
+     def get_velocity(self):
+         """Get current velocity estimate [dx, dy, d_scale]."""
+         return self.kf.x[4:7].flatten().copy()
+
+
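A minimal usage sketch for KalmanBoxTracker (illustrative only, not part of the wheel; it assumes filterpy is installed and uses made-up boxes):

    # Initialize on frame 0, observe on frame 1, then coast through an
    # occlusion with the OC-SORT virtual trajectory.
    tracker = KalmanBoxTracker([10, 20, 50, 80])
    tracker.update([12, 22, 52, 82], frame_idx=1)
    predicted = tracker.predict()                         # frame 2, no observation
    coasted = tracker.predict_with_virtual_trajectory(3)  # frame 3, still unobserved
    print(predicted, coasted, tracker.get_velocity())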
+ def mask_to_bbox(mask):
+     """Convert binary mask to bounding box [x1, y1, x2, y2]."""
+     if mask is None or mask.max() == 0:
+         return None
+     ys, xs = np.where(mask > 0)
+     if len(ys) == 0:
+         return None
+     # Use [x1, y1, x2, y2) convention (exclusive max corner) for consistent area math.
+     return np.array([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1], dtype=np.float32)
+
+
+ def compute_iou(bbox1, bbox2):
+     """Compute IoU between two bboxes [x1, y1, x2, y2]."""
+     if bbox1 is None or bbox2 is None:
+         return 0.0
+
+     x1 = max(bbox1[0], bbox2[0])
+     y1 = max(bbox1[1], bbox2[1])
+     x2 = min(bbox1[2], bbox2[2])
+     y2 = min(bbox1[3], bbox2[3])
+
+     inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+
+     area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+     area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+
+     union_area = area1 + area2 - inter_area
+
+     if union_area <= 0:
+         return 0.0
+
+     return inter_area / union_area
+
+
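A quick worked example for compute_iou with hypothetical boxes: two 10x10 boxes overlapping in a 5x5 region give an intersection of 25 and a union of 100 + 100 - 25 = 175, so IoU = 25/175 ≈ 0.143:

    compute_iou(np.array([0, 0, 10, 10]), np.array([5, 5, 15, 15]))  # ≈ 0.143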
+ def compute_mask_score(mask_logit, predicted_bbox=None, actual_bbox=None,
+                        use_multiplicative=True):
+     """
+     Compute quality score for a mask prediction (SAMURAI-style).
+
+     SAMURAI formula: score = sigmoid(obj_score) * sigmoid(IoU)
+     This multiplicative approach is stricter: if either component is low,
+     the whole score drops significantly.
+
+     Args:
+         mask_logit: Raw mask logits from SAM2
+         predicted_bbox: Predicted bbox from Kalman filter
+         actual_bbox: Actual bbox from mask
+         use_multiplicative: If True, use SAMURAI's multiplicative scoring.
+             If False, use additive (average).
+
+     Returns:
+         score: Quality score between 0 and 1
+         obj_score: Objectness component
+         motion_iou: Motion consistency component
+     """
+     # Objectness score: maximum sigmoid-mapped confidence over the mask logits
+     if isinstance(mask_logit, torch.Tensor):
+         obj_score = torch.sigmoid(mask_logit).max().item()
+     else:
+         obj_score = 1.0 / (1.0 + np.exp(-mask_logit.max()))
+
+     # Motion consistency score (IoU between predicted and actual bbox)
+     if predicted_bbox is not None and actual_bbox is not None:
+         motion_iou = compute_iou(predicted_bbox, actual_bbox)
+     else:
+         motion_iou = 1.0  # No motion prediction available
+
+     # SAMURAI applies a sigmoid to the IoU as well for normalization, but a
+     # plain sigmoid maps [0, 1] to roughly [0.5, 0.73]. We therefore use a
+     # scaled sigmoid mapping 0 -> ~0.08, 0.5 -> 0.5, 1 -> ~0.92 for better range.
+     def soft_sigmoid(x, k=5.0):
+         """Scaled sigmoid over [0, 1] with a fixed point at 0.5."""
+         return 1.0 / (1.0 + np.exp(-k * (x - 0.5)))
+
+     if use_multiplicative:
+         # SAMURAI-style: multiplicative scoring (stricter);
+         # both components must be high for a good score.
+         normalized_iou = soft_sigmoid(motion_iou)
+         score = obj_score * normalized_iou
+     else:
+         # Additive scoring (more lenient)
+         score = 0.5 * obj_score + 0.5 * motion_iou
+
+     return score, obj_score, motion_iou
+
+
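A small sketch of the two scoring modes (hypothetical logits and boxes; the exact numbers depend on the values chosen):

    logits = np.array([[-4.0, 2.0], [1.5, -3.0]])       # fake mask logits
    pred = np.array([0, 0, 10, 10], dtype=np.float32)   # Kalman prediction
    obs = np.array([2, 2, 12, 12], dtype=np.float32)    # bbox from the mask
    strict = compute_mask_score(logits, pred, obs)      # multiplicative (SAMURAI-style)
    lenient = compute_mask_score(logits, pred, obs, use_multiplicative=False)
    # Each call returns (score, obj_score, motion_iou).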
+ class AppearanceMemoryBank:
+     """
+     Long-term appearance memory that stores high-quality "golden" masks per object.
+     When an object is lost for many frames (long occlusion) and starts reappearing,
+     the best stored mask is used to re-seed SAM2 so appearance is recovered.
+     """
+     def __init__(self, max_snapshots=5, min_score_to_store=0.6,
+                  occlusion_enter_frames=10, recovery_area_ratio=0.05,
+                  reseed_debounce_frames=10, shape_area_ratio_range=(0.7, 1.3),
+                  max_aspect_ratio_change=0.35):
+         """
+         Args:
+             max_snapshots: max golden masks kept per object (best-scoring kept)
+             min_score_to_store: minimum quality score to consider a frame "golden"
+             occlusion_enter_frames: consecutive zero/low frames before entering occlusion state
+             recovery_area_ratio: when current area / golden area >= this, consider partial
+                 recovery. Set very low (0.05 = 5%) so even a sliver of tail triggers reseed.
+             reseed_debounce_frames: min frames between consecutive re-seeds for same object
+             shape_area_ratio_range: accepted candidate area ratio range vs learned shape prior
+             max_aspect_ratio_change: max relative aspect-ratio change vs learned shape prior
+         """
+         self.max_snapshots = max_snapshots
+         self.min_score_to_store = min_score_to_store
+         self.occlusion_enter_frames = occlusion_enter_frames
+         self.recovery_area_ratio = recovery_area_ratio
+         self.reseed_debounce_frames = reseed_debounce_frames
+         self.shape_area_ratio_range = shape_area_ratio_range
+         self.max_aspect_ratio_change = max_aspect_ratio_change
+
+         # Per-object storage
+         self.snapshots = {}  # obj_id -> [(score, frame_idx, mask, bbox, area)]
+         self.golden_area = {}  # obj_id -> median area from golden masks
+         self.golden_aspect = {}  # obj_id -> median aspect ratio from golden masks
+         self.zero_streak = {}  # obj_id -> consecutive frames with empty/tiny mask
+         self.occluded = {}  # obj_id -> bool
+         self.recovery_pending = {}  # obj_id -> bool (just left occlusion, needs re-seed)
+         self.last_reseed_frame = {}  # obj_id -> frame_idx of last re-seed (debounce)
+
+     @staticmethod
+     def _bbox_aspect_ratio(bbox):
+         if bbox is None:
+             return None
+         w = max(float(bbox[2] - bbox[0]), 1.0)
+         h = max(float(bbox[3] - bbox[1]), 1.0)
+         return w / h
+
+     def _passes_shape_guard(self, obj_id, area, bbox):
+         """Keep golden masks close to the learned prompt-shape profile."""
+         ref_area = self.golden_area.get(obj_id)
+         if ref_area is not None and ref_area > 0:
+             ratio = float(area) / float(ref_area)
+             lo, hi = self.shape_area_ratio_range
+             if ratio < lo or ratio > hi:
+                 return False
+
+         ref_aspect = self.golden_aspect.get(obj_id)
+         cur_aspect = self._bbox_aspect_ratio(bbox)
+         if ref_aspect is not None and ref_aspect > 0 and cur_aspect is not None:
+             rel_change = abs(cur_aspect - ref_aspect) / ref_aspect
+             if rel_change > self.max_aspect_ratio_change:
+                 return False
+         return True
+
+     def store_if_golden(self, obj_id, mask, bbox, area, score, frame_idx):
+         """Store mask snapshot if quality is high enough."""
+         if score < self.min_score_to_store or mask is None or area < 1:
+             return
+         if not self._passes_shape_guard(obj_id, area, bbox):
+             return
+         if obj_id not in self.snapshots:
+             self.snapshots[obj_id] = []
+
+         entry = (score, frame_idx, mask.copy(), bbox.copy() if bbox is not None else None, area)
+         snaps = self.snapshots[obj_id]
+         snaps.append(entry)
+         snaps.sort(key=lambda e: e[0], reverse=True)
+         if len(snaps) > self.max_snapshots:
+             snaps[:] = snaps[:self.max_snapshots]
+
+         # Refresh the reference ("golden") area as the median across snapshots.
+         areas = [s[4] for s in snaps]
+         self.golden_area[obj_id] = float(np.median(areas))
+         aspects = [self._bbox_aspect_ratio(s[3]) for s in snaps if s[3] is not None]
+         if aspects:
+             self.golden_aspect[obj_id] = float(np.median(aspects))
+
+     def update_occlusion_state(self, obj_id, mask_area, frame_idx):
+         """
+         Track whether object is in long occlusion and detect recovery.
+         Returns True if a re-seed should happen this frame.
+         """
+         golden = self.golden_area.get(obj_id)
+         if golden is None or golden < 1:
+             return False
+
+         area_fraction = mask_area / golden
+
+         # Object is mostly gone?
+         if area_fraction < 0.1:
+             self.zero_streak[obj_id] = self.zero_streak.get(obj_id, 0) + 1
+         else:
+             self.zero_streak[obj_id] = 0
+
+         was_occluded = self.occluded.get(obj_id, False)
+
+         # Enter occlusion mode after sustained absence
+         if self.zero_streak.get(obj_id, 0) >= self.occlusion_enter_frames:
+             self.occluded[obj_id] = True
+
+         # Detect recovery: was occluded, now partial mask is back
+         if was_occluded and area_fraction >= self.recovery_area_ratio:
+             self.occluded[obj_id] = False
+             self.zero_streak[obj_id] = 0
+
+             # Debounce: don't re-seed too frequently
+             last = self.last_reseed_frame.get(obj_id, -999)
+             if frame_idx - last >= self.reseed_debounce_frames:
+                 self.recovery_pending[obj_id] = True
+                 self.last_reseed_frame[obj_id] = frame_idx
+                 return True
+         return False
+
+     def pop_reseed_mask(self, obj_id):
+         """
+         Return the best golden mask for re-seeding, or None.
+         Clears the pending flag.
+         """
+         self.recovery_pending.pop(obj_id, None)
+         snaps = self.snapshots.get(obj_id)
+         if not snaps:
+             return None
+         return snaps[0][2]  # highest-score snapshot mask array
+
+     def is_recovery_pending(self, obj_id):
+         return self.recovery_pending.get(obj_id, False)
+
+     def has_snapshots(self, obj_id):
+         return bool(self.snapshots.get(obj_id))
+
+
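A sketch of the intended memory-bank lifecycle (hypothetical sizes and thresholds; the re-seed mask would be fed back to SAM2 as a mask prompt):

    bank = AppearanceMemoryBank(occlusion_enter_frames=3, recovery_area_ratio=0.15)
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[10:40, 10:40] = 1  # 900-pixel object
    bank.store_if_golden(1, mask, mask_to_bbox(mask), area=900, score=0.9, frame_idx=0)
    for f in range(1, 5):                                # object absent
        bank.update_occlusion_state(1, mask_area=0, frame_idx=f)
    if bank.update_occlusion_state(1, mask_area=300, frame_idx=5):  # partial return
        reseed = bank.pop_reseed_mask(1)                 # golden mask for re-prompting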
+ class MultiObjectMotionTracker:
+     """
+     Manages Kalman filter trackers for multiple objects with
+     motion-aware tracking, drift detection, and automatic prompt injection.
+
+     Supports OC-SORT enhancements:
+     - Virtual trajectory during occlusions
+     - ORU (Observation-Centric Re-Update) for drift correction
+     """
+     def __init__(self, motion_score_threshold=0.3, use_kalman=True,
+                  consecutive_low_threshold=3, area_change_threshold=0.5,
+                  use_ocsort=False, ocsort_inertia=0.2, max_history_frames=1000,
+                  adaptive_threshold=True, threshold_window=30, hysteresis_margin=0.05,
+                  max_correction_jump_px=80.0, max_correction_area_ratio=2.5,
+                  enable_appearance_memory=True, appearance_min_score=0.6,
+                  appearance_max_snapshots=5, occlusion_enter_frames=5,
+                  recovery_area_ratio=0.15, reseed_debounce_frames=5,
+                  shape_area_ratio_range=(0.7, 1.3), max_aspect_ratio_change=0.35):
+         self.trackers = {}  # obj_id -> KalmanBoxTracker
+         self.scores = {}  # obj_id -> {frame_idx: score}
+         self.bboxes = {}  # obj_id -> {frame_idx: bbox} for temporal consistency
+         self.areas = {}  # obj_id -> {frame_idx: area}
+         self.motion_score_threshold = motion_score_threshold
+         self.use_kalman = use_kalman
+         self.kalman_available = True
+         self.consecutive_low_threshold = consecutive_low_threshold  # Frames before recovery
+         self.area_change_threshold = area_change_threshold  # Max allowed area change ratio
+         self.consecutive_low_count = {}  # obj_id -> count of consecutive low scores
+         self.needs_correction = {}  # obj_id -> bool
+
+         # OC-SORT parameters
+         self.use_ocsort = use_ocsort
+         self.ocsort_inertia = ocsort_inertia  # Velocity smoothing (paper: 0.2)
+         self.max_history_frames = max_history_frames
+         self.adaptive_threshold = adaptive_threshold
+         self.threshold_window = threshold_window
+         self.hysteresis_margin = hysteresis_margin
+         self.max_correction_jump_px = max_correction_jump_px
+         self.max_correction_area_ratio = max_correction_area_ratio
+         self.low_state = {}  # obj_id -> bool (hysteresis state)
+
+         # Long-term appearance memory for occlusion recovery
+         self.enable_appearance_memory = enable_appearance_memory
+         self.appearance_memory = AppearanceMemoryBank(
+             max_snapshots=appearance_max_snapshots,
+             min_score_to_store=appearance_min_score,
+             occlusion_enter_frames=occlusion_enter_frames,
+             recovery_area_ratio=recovery_area_ratio,
+             reseed_debounce_frames=reseed_debounce_frames,
+             shape_area_ratio_range=shape_area_ratio_range,
+             max_aspect_ratio_change=max_aspect_ratio_change,
+         ) if enable_appearance_memory else None
+
+         try:
+             from filterpy.kalman import KalmanFilter  # noqa: F401 (availability check)
+         except ImportError:
+             self.kalman_available = False
+             self.use_kalman = False
+
+     def initialize_tracker(self, obj_id, bbox, frame_idx=0):
+         """Initialize a new tracker for an object."""
+         if not self.use_kalman or not self.kalman_available:
+             return
+         if bbox is None:
+             return
+         try:
+             self.trackers[obj_id] = KalmanBoxTracker(
+                 bbox,
+                 delta_t=1,
+                 inertia=self.ocsort_inertia
+             )
+             self.trackers[obj_id].last_observation_frame = frame_idx
+             self.scores[obj_id] = {}
+             self.bboxes[obj_id] = {}
+             self.areas[obj_id] = {}
+             self.consecutive_low_count[obj_id] = 0
+             self.needs_correction[obj_id] = False
+             self.low_state[obj_id] = False
+         except Exception as e:
+             logger.debug("Failed to initialize tracker for obj %s: %s", obj_id, e)
+
+     def predict(self, obj_id):
+         """Return the current Kalman state estimate without advancing it."""
+         if obj_id not in self.trackers:
+             return None
+         try:
+             return self.trackers[obj_id].get_state()
+         except Exception as e:
+             logger.debug("Failed to get state for obj %s: %s", obj_id, e)
+             return None
+
+     def predict_and_advance(self, obj_id, frame_idx=None):
+         """Advance the Kalman state by one step and return the predicted bbox.
+
+         When OC-SORT is enabled and a frame index is provided, the tracker
+         uses a virtual trajectory so occlusions do not zero out velocity.
+         """
+         if obj_id not in self.trackers:
+             return None
+         try:
+             if self.use_ocsort and frame_idx is not None:
+                 return self.trackers[obj_id].predict_with_virtual_trajectory(frame_idx)
+             else:
+                 return self.trackers[obj_id].predict()
+         except Exception as e:
+             logger.debug("Failed to predict tracker for obj %s: %s", obj_id, e)
+             return None
+
+     def get_predicted_bbox_for_correction(self, obj_id):
+         """Return the Kalman bbox as an [x1, y1, x2, y2] box prompt."""
+         if obj_id not in self.trackers:
+             return None
+         try:
+             bbox = self.trackers[obj_id].get_state()
+             if bbox is None:
+                 return None
+             return bbox.tolist()
+         except Exception as e:
+             logger.debug("Failed to get correction bbox for obj %s: %s", obj_id, e)
+             return None
+
+     def get_frame_score(self, obj_id, frame_idx):
+         """Return the stored score for a frame, or None."""
+         if obj_id not in self.scores:
+             return None
+         return self.scores[obj_id].get(frame_idx)
+
+     def get_effective_threshold(self, obj_id):
+         """Adapt the score threshold to the object's recent score statistics."""
+         base = self.motion_score_threshold
+         if not self.adaptive_threshold or obj_id not in self.scores:
+             return base
+         frames = sorted(self.scores[obj_id].keys())[-self.threshold_window:]
+         if len(frames) < 5:
+             return base
+         vals = np.array([self.scores[obj_id][f] for f in frames], dtype=np.float32)
+         mean = float(vals.mean())
+         std = float(vals.std())
+         adaptive = mean - 0.5 * std
+         lower = max(0.05, base - 0.15)
+         upper = min(0.95, base + 0.15)
+         return float(np.clip(adaptive, lower, upper))
+
+     def is_correction_bbox_sane(self, obj_id, pred_bbox):
+         """Reject correction boxes that jump or rescale implausibly."""
+         if pred_bbox is None or obj_id not in self.bboxes or not self.bboxes[obj_id]:
+             return False
+         last_frame = max(self.bboxes[obj_id].keys())
+         last_bbox = self.bboxes[obj_id][last_frame]
+         if last_bbox is None:
+             return False
+
+         pred_bbox = np.asarray(pred_bbox, dtype=np.float32)
+         last_bbox = np.asarray(last_bbox, dtype=np.float32)
+
+         pred_center = np.array(
+             [(pred_bbox[0] + pred_bbox[2]) * 0.5, (pred_bbox[1] + pred_bbox[3]) * 0.5],
+             dtype=np.float32,
+         )
+         last_center = np.array(
+             [(last_bbox[0] + last_bbox[2]) * 0.5, (last_bbox[1] + last_bbox[3]) * 0.5],
+             dtype=np.float32,
+         )
+         center_shift = float(np.linalg.norm(pred_center - last_center))
+         if center_shift > self.max_correction_jump_px:
+             return False
+
+         pred_area = max((pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1]), 1.0)
+         last_area = max((last_bbox[2] - last_bbox[0]) * (last_bbox[3] - last_bbox[1]), 1.0)
+         area_ratio = max(pred_area / last_area, last_area / pred_area)
+         return area_ratio <= self.max_correction_area_ratio
+
+     def update(self, obj_id, mask, mask_logit, frame_idx):
+         """
+         Update tracker with observed mask and compute quality score.
+
+         Returns:
+             score: Quality score for this frame
+             should_use_for_memory: Whether this frame should be used for memory
+         """
+         actual_bbox = mask_to_bbox(mask)
+
+         if actual_bbox is None:
+             if obj_id in self.consecutive_low_count:
+                 self.consecutive_low_count[obj_id] += 1
+             # OC-SORT: Still predict with virtual trajectory to maintain state
+             if self.use_ocsort and obj_id in self.trackers:
+                 try:
+                     self.trackers[obj_id].predict_with_virtual_trajectory(frame_idx)
+                 except Exception as e:
+                     logger.debug(
+                         "Failed virtual trajectory for obj %s at frame %s: %s",
+                         obj_id,
+                         frame_idx,
+                         e,
+                     )
+             return 0.0, False
+
+         mask_area = float(np.sum(mask > 0))
+
+         if obj_id not in self.trackers:
+             self.initialize_tracker(obj_id, actual_bbox, frame_idx)
+             if obj_id not in self.scores:
+                 self.scores[obj_id] = {}
+             self.scores[obj_id][frame_idx] = 1.0
+             self.bboxes[obj_id] = {frame_idx: actual_bbox}
+             self.areas[obj_id] = {frame_idx: mask_area}
+             self.low_state[obj_id] = False
+             return 1.0, True
+
+         # OC-SORT virtual trajectory: advance Kalman state through occlusions
+         # using frame_idx as the time index.
+         predicted_bbox = self.predict_and_advance(obj_id, frame_idx) if self.use_kalman else None
+
+         score, obj_score, motion_iou = compute_mask_score(
+             mask_logit, predicted_bbox, actual_bbox
+         )
+
+         # Temporal consistency: large area jumps usually indicate mask drift,
+         # so penalise the score when the current area deviates from the
+         # running mean over the last five frames.
+         if obj_id in self.areas and self.areas[obj_id]:
+             recent_frames = sorted(self.areas[obj_id].keys())[-5:]
+             if recent_frames:
+                 avg_area = np.mean([self.areas[obj_id][f] for f in recent_frames])
+                 if avg_area > 0:
+                     area_ratio = mask_area / avg_area
+                     if area_ratio < (1 - self.area_change_threshold) or \
+                             area_ratio > (1 + self.area_change_threshold):
+                         score *= 0.5
+
+         # OC-SORT Observation-Centric Re-Update: feed the real observation back
+         # into the tracker keyed on frame_idx.
+         if self.use_kalman and obj_id in self.trackers:
+             try:
+                 self.trackers[obj_id].update(actual_bbox, frame_idx)
+             except Exception as e:
+                 logger.debug("Failed to update tracker for obj %s at frame %s: %s", obj_id, frame_idx, e)
+
+         # Store data
+         if obj_id not in self.scores:
+             self.scores[obj_id] = {}
+         self.scores[obj_id][frame_idx] = score
+         self.bboxes[obj_id][frame_idx] = actual_bbox
+         self.areas[obj_id][frame_idx] = mask_area
+         self._prune_history(obj_id)
+
+         # Adaptive threshold + hysteresis to reduce flicker in keep/drop decisions.
+         threshold = self.get_effective_threshold(obj_id)
+         low_cut = max(0.0, threshold - self.hysteresis_margin)
+         high_cut = min(1.0, threshold + self.hysteresis_margin)
+         in_low_state = self.low_state.get(obj_id, False)
+         if in_low_state:
+             should_use = score >= high_cut
+         else:
+             should_use = score >= low_cut
+         self.low_state[obj_id] = not should_use
+
+         # Track consecutive low scores for drift detection
+         if not should_use:
+             self.consecutive_low_count[obj_id] = self.consecutive_low_count.get(obj_id, 0) + 1
+         else:
+             self.consecutive_low_count[obj_id] = 0
+
+         # Flag for correction if too many consecutive low scores
+         if self.consecutive_low_count.get(obj_id, 0) >= self.consecutive_low_threshold:
+             self.needs_correction[obj_id] = True
+         else:
+             self.needs_correction[obj_id] = False
+
+         # Long-term appearance memory: store golden masks + detect occlusion recovery
+         if self.appearance_memory is not None:
+             self.appearance_memory.store_if_golden(
+                 obj_id, mask, actual_bbox, mask_area, score, frame_idx
+             )
+             self.appearance_memory.update_occlusion_state(obj_id, mask_area, frame_idx)
+
+         return score, should_use
+
+     def _prune_history(self, obj_id):
+         """Keep only recent history to cap memory growth."""
+         if self.max_history_frames is None or self.max_history_frames <= 0:
+             return
+         for store in (self.scores.get(obj_id), self.bboxes.get(obj_id), self.areas.get(obj_id)):
+             if not store or len(store) <= self.max_history_frames:
+                 continue
+             trim_count = len(store) - self.max_history_frames
+             for frame_idx in sorted(store.keys())[:trim_count]:
+                 del store[frame_idx]
+
+     def check_needs_correction(self, obj_id):
+         """Check if object needs prompt correction due to drift."""
+         return self.needs_correction.get(obj_id, False)
+
+     def reset_correction_flag(self, obj_id):
+         """Reset correction flag after applying correction."""
+         self.needs_correction[obj_id] = False
+         self.consecutive_low_count[obj_id] = 0
+         self.low_state[obj_id] = False
+
+     def get_low_score_frames(self, obj_id, threshold=None):
+         """Get list of frames with scores below threshold."""
+         if threshold is None:
+             threshold = self.motion_score_threshold
+         if obj_id not in self.scores:
+             return []
+         return [f for f, s in self.scores[obj_id].items() if s < threshold]
+
+     def get_recent_scores(self, obj_id, n_frames=10):
+         """Get scores for last n frames."""
+         if obj_id not in self.scores:
+             return []
+         frames = sorted(self.scores[obj_id].keys())[-n_frames:]
+         return [(f, self.scores[obj_id][f]) for f in frames]
+
+     def get_best_memory_frames(self, obj_id, n_frames=6):
+         """
+         Get the N highest-scoring frames for memory bank prioritization.
+         Returns list of (frame_idx, score) tuples.
+         """
+         if obj_id not in self.scores:
+             return []
+         sorted_frames = sorted(self.scores[obj_id].items(), key=lambda x: x[1], reverse=True)
+         return sorted_frames[:n_frames]
+
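Finally, a hedged sketch of how a caller might drive MultiObjectMotionTracker per frame; `frames` is a placeholder for per-frame SAM2 outputs ({obj_id: (mask, logits)}), not an API of this module:

    tracker = MultiObjectMotionTracker(use_ocsort=True)
    for frame_idx, per_obj in enumerate(frames):
        for obj_id, (mask, logits) in per_obj.items():
            score, keep = tracker.update(obj_id, mask, logits, frame_idx)
            if not keep and tracker.check_needs_correction(obj_id):
                box = tracker.get_predicted_bbox_for_correction(obj_id)
                if box is not None and tracker.is_correction_bbox_sane(obj_id, box):
                    # Re-prompt SAM2 with `box` here, then clear the flag.
                    tracker.reset_correction_flag(obj_id)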