singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,764 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Motion-Aware Tracking Enhancements for SAM2.
|
|
3
|
+
|
|
4
|
+
Includes:
|
|
5
|
+
- Kalman filter-based motion prediction (SAMURAI-style)
|
|
6
|
+
- OC-SORT drift correction (virtual trajectory + ORU)
|
|
7
|
+
- Mask quality scoring and temporal consistency checks
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class KalmanBoxTracker:
    """
    Kalman filter for tracking bounding boxes with OC-SORT enhancements.
    State: [x_center, y_center, scale, aspect_ratio, dx, dy, d_scale]

    The first four state entries mirror the measurement vector; the last
    three are per-frame velocities for position and scale (aspect ratio is
    modelled as constant, so it carries no velocity term).

    OC-SORT features:
    - Virtual trajectory: During occlusion, generate virtual observations from velocity
    - ORU (Observation-Centric Re-Update): When object reappears, correct past states
    """
    def __init__(self, bbox, delta_t=1, inertia=0.2):
        """
        Initialize tracker with bounding box [x1, y1, x2, y2].

        Args:
            bbox: Initial bounding box [x1, y1, x2, y2]
            delta_t: Frame interval for virtual trajectory (OC-SORT paper: 1)
            inertia: Velocity smoothing factor for ORU (OC-SORT paper: 0.2)
        """
        # Lazy import: the module can still be loaded when filterpy is absent;
        # callers probe availability separately (see MultiObjectMotionTracker).
        from filterpy.kalman import KalmanFilter

        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        self.delta_t = delta_t
        self.inertia = inertia

        # State transition matrix (constant velocity model):
        # x/y/scale each advance by their own velocity; aspect ratio is static.
        self.kf.F = np.array([
            [1, 0, 0, 0, 1, 0, 0],
            [0, 1, 0, 0, 0, 1, 0],
            [0, 0, 1, 0, 0, 0, 1],
            [0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1]
        ], dtype=np.float32)

        # Measurement matrix: only the first four state entries are observed.
        self.kf.H = np.array([
            [1, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0]
        ], dtype=np.float32)

        # Measurement noise covariance
        self.kf.R *= 10.0
        self.kf.R[2, 2] *= 10.0  # Scale measurement has higher noise

        # Process noise covariance
        self.kf.P[4:, 4:] *= 1000.0  # High uncertainty in velocities initially
        self.kf.P *= 10.0

        # Damp process noise on velocity terms so the filter does not
        # over-trust its own motion model (scale velocity damped twice).
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01

        # Seed position/scale/aspect from the initial box; velocities start at 0.
        self.kf.x[:4] = self._bbox_to_z(bbox).reshape(4, 1)

        self.time_since_update = 0  # frames since the last real observation
        self.history = []           # predicted bboxes accumulated since last update
        self.hits = 0               # number of successful updates
        self.age = 0                # total prediction steps taken

        # OC-SORT: Store observations for ORU
        self.observations = {}  # frame_idx -> observation (z)
        self.last_observation_frame = 0
        self.last_observation = self._bbox_to_z(bbox)
        self.frozen_velocity = None  # Velocity frozen at occlusion start

    def _bbox_to_z(self, bbox):
        """Convert [x1, y1, x2, y2] to [x_center, y_center, scale, aspect_ratio]."""
        x1, y1, x2, y2 = bbox
        w = x2 - x1
        h = y2 - y1
        x_c = x1 + w / 2.0
        y_c = y1 + h / 2.0
        s = w * h  # scale (area)
        r = w / (h + 1e-6)  # aspect ratio
        return np.array([x_c, y_c, s, r], dtype=np.float32)

    def _z_to_bbox(self, z):
        """Convert [x_center, y_center, scale, aspect_ratio] to [x1, y1, x2, y2]."""
        x_c, y_c, s, r = z.flatten()[:4]
        # Recover width from area * aspect; the max() clamp keeps sqrt real
        # and forces at least a 1px-area box even for degenerate states.
        w = np.sqrt(max(s * r, 1.0))
        h = s / (w + 1e-6)
        return np.array([x_c - w/2, y_c - h/2, x_c + w/2, y_c + h/2], dtype=np.float32)

    def predict(self):
        """Predict next state and return predicted bbox."""
        # SORT-style guard: if adding the scale velocity would make the scale
        # non-positive, zero the scale velocity before predicting.
        if self.kf.x[6] + self.kf.x[2] <= 0:
            self.kf.x[6] = 0.0
        self.kf.predict()
        self.age += 1
        self.time_since_update += 1
        self.history.append(self._z_to_bbox(self.kf.x))
        return self.history[-1]

    def predict_with_virtual_trajectory(self, frame_idx):
        """
        OC-SORT: Predict with virtual trajectory during occlusion.
        Uses last known velocity to generate virtual observations.

        Args:
            frame_idx: Current frame index

        Returns:
            Predicted bbox
        """
        if self.time_since_update == 0:
            # Not occluded, use normal prediction
            return self.predict()

        # Freeze velocity at start of occlusion
        if self.frozen_velocity is None:
            self.frozen_velocity = self.kf.x[4:7].copy()

        # Generate virtual observation using frozen velocity
        virtual_z = self.last_observation.copy()
        dt = frame_idx - self.last_observation_frame
        if dt > 0:
            virtual_z[0] += self.frozen_velocity[0, 0] * dt  # x_center
            virtual_z[1] += self.frozen_velocity[1, 0] * dt  # y_center
            virtual_z[2] += self.frozen_velocity[2, 0] * dt  # scale
            virtual_z[2] = max(virtual_z[2], 1.0)  # scale must stay positive

        # Fold the virtual observation into the Kalman state so predicted and
        # observed trajectories stay in sync during occlusions.
        self.kf.predict()
        self.kf.update(virtual_z)
        self.age += 1
        self.time_since_update += 1

        bbox = self._z_to_bbox(self.kf.x)
        self.history.append(bbox)
        return bbox

    def update(self, bbox, frame_idx=None):
        """Update state with observed bbox.

        Args:
            bbox: Observed bounding box [x1, y1, x2, y2].
            frame_idx: Optional frame index; needed for ORU and for the
                observation store consulted by the virtual trajectory.
        """
        z = self._bbox_to_z(bbox)

        # OC-SORT: Apply ORU if recovering from occlusion
        if self.time_since_update > 0 and frame_idx is not None:
            self._apply_oru(z, frame_idx)

        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.kf.update(z)

        # Store observation for future ORU
        if frame_idx is not None:
            self.observations[frame_idx] = z.copy()
            self.last_observation_frame = frame_idx
        self.last_observation = z.copy()
        self.frozen_velocity = None  # Reset frozen velocity

    def _apply_oru(self, new_z, frame_idx):
        """
        OC-SORT: Observation-Centric Re-Update.
        When object reappears after occlusion, interpolate backwards
        to correct past Kalman state drift.

        Args:
            new_z: New observation [x_c, y_c, s, r]
            frame_idx: Current frame index
        """
        if self.time_since_update <= 1:
            return  # ORU only matters after more than one missed frame

        dt = frame_idx - self.last_observation_frame
        if dt <= 0:
            return

        # Average velocity implied by the gap between the last real
        # observation and the new one.
        velocity = (new_z - self.last_observation) / dt

        # Apply velocity smoothing with inertia
        # Mix new velocity with old frozen velocity
        if self.frozen_velocity is not None:
            old_vel = self.frozen_velocity.flatten()[:3]
            new_vel = velocity[:3]
            smoothed_vel = self.inertia * old_vel + (1 - self.inertia) * new_vel
        else:
            smoothed_vel = velocity[:3]

        # Re-update state with corrected velocity, written straight into the
        # state vector (aspect-ratio velocity is not modelled).
        self.kf.x[4, 0] = smoothed_vel[0]  # dx
        self.kf.x[5, 0] = smoothed_vel[1]  # dy
        self.kf.x[6, 0] = smoothed_vel[2]  # d_scale

        # Reduce uncertainty since we now have a good observation
        self.kf.P[4:, 4:] *= 0.5

    def get_state(self):
        """Get current bbox estimate."""
        return self._z_to_bbox(self.kf.x)

    def get_velocity(self):
        """Get current velocity estimate [dx, dy, d_scale]."""
        return self.kf.x[4:7].flatten().copy()
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def mask_to_bbox(mask):
    """Return the [x1, y1, x2, y2) bounding box of a binary mask, or None if empty."""
    if mask is None:
        return None
    if mask.max() == 0:
        return None
    rows, cols = np.nonzero(mask > 0)
    if rows.size == 0:
        return None
    # Exclusive max corner so that width/height equal plain coordinate
    # differences, keeping area math consistent downstream.
    x_min, x_max = cols.min(), cols.max() + 1
    y_min, y_max = rows.min(), rows.max() + 1
    return np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def compute_iou(bbox1, bbox2):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes (0.0 if either is None)."""
    if bbox1 is None or bbox2 is None:
        return 0.0

    # Overlap rectangle (may be empty, in which case width/height clamp to 0).
    ix1 = max(bbox1[0], bbox2[0])
    iy1 = max(bbox1[1], bbox2[1])
    ix2 = min(bbox1[2], bbox2[2])
    iy2 = min(bbox1[3], bbox2[3])
    intersection = max(0, ix2 - ix1) * max(0, iy2 - iy1)

    area_a = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    area_b = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    union = area_a + area_b - intersection

    # Degenerate boxes give a non-positive union; treat as no overlap.
    return intersection / union if union > 0 else 0.0
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def compute_mask_score(mask_logit, predicted_bbox=None, actual_bbox=None,
                       use_multiplicative=True):
    """
    Compute quality score for a mask prediction (SAMURAI-style).

    Combines an objectness term (peak sigmoid confidence over the raw
    logits) with a motion-consistency term (IoU between the motion-predicted
    box and the box of the actual mask). The multiplicative variant is
    stricter: a weak component drags the whole score down.

    Args:
        mask_logit: Raw mask logits from SAM2 (torch tensor or numpy array)
        predicted_bbox: Predicted bbox from Kalman filter
        actual_bbox: Actual bbox from mask
        use_multiplicative: If True, use SAMURAI's multiplicative scoring.
            If False, use additive (average).

    Returns:
        score: Quality score between 0 and 1
        obj_score: Objectness component
        motion_iou: Motion consistency component
    """
    # Objectness: peak sigmoid confidence over the logit map.
    if isinstance(mask_logit, torch.Tensor):
        obj_score = torch.sigmoid(mask_logit).max().item()
    else:
        obj_score = 1.0 / (1.0 + np.exp(-mask_logit.max()))

    # Motion consistency defaults to 1.0 when no prediction is available.
    motion_iou = 1.0
    if predicted_bbox is not None and actual_bbox is not None:
        motion_iou = compute_iou(predicted_bbox, actual_bbox)

    def _rescaled_sigmoid(value, k=5.0):
        # Sigmoid recentred on 0.5: a plain sigmoid would squash [0, 1] into
        # roughly [0.5, 0.73]; this variant keeps 0 -> ~0, 0.5 -> 0.5, 1 -> ~1.
        return 1.0 / (1.0 + np.exp(-k * (value - 0.5)))

    if use_multiplicative:
        # SAMURAI-style: both components must be high for a good score.
        score = obj_score * _rescaled_sigmoid(motion_iou)
    else:
        # Additive scoring (more lenient).
        score = 0.5 * obj_score + 0.5 * motion_iou

    return score, obj_score, motion_iou
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class AppearanceMemoryBank:
    """
    Long-term appearance memory that stores high-quality "golden" masks per object.
    When an object is lost for many frames (long occlusion) and starts reappearing,
    the best stored mask is used to re-seed SAM2 so appearance is recovered.
    """
    def __init__(self, max_snapshots=5, min_score_to_store=0.6,
                 occlusion_enter_frames=10, recovery_area_ratio=0.05,
                 reseed_debounce_frames=10, shape_area_ratio_range=(0.7, 1.3),
                 max_aspect_ratio_change=0.35):
        """
        Args:
            max_snapshots: max golden masks kept per object (best-scoring kept)
            min_score_to_store: minimum quality score to consider a frame "golden"
            occlusion_enter_frames: consecutive zero/low frames before entering occlusion state
            recovery_area_ratio: when current area / golden area >= this, consider partial
                recovery. Set very low (0.05 = 5%) so even a sliver of tail triggers reseed.
            reseed_debounce_frames: min frames between consecutive re-seeds for same object
            shape_area_ratio_range: accepted candidate area ratio range vs learned shape prior
            max_aspect_ratio_change: max relative aspect-ratio change vs learned shape prior
        """
        self.max_snapshots = max_snapshots
        self.min_score_to_store = min_score_to_store
        self.occlusion_enter_frames = occlusion_enter_frames
        self.recovery_area_ratio = recovery_area_ratio
        self.reseed_debounce_frames = reseed_debounce_frames
        self.shape_area_ratio_range = shape_area_ratio_range
        self.max_aspect_ratio_change = max_aspect_ratio_change

        # Per-object storage
        self.snapshots = {}  # obj_id -> [(score, frame_idx, mask, bbox, area)]
        self.golden_area = {}  # obj_id -> median area from golden masks
        self.golden_aspect = {}  # obj_id -> median aspect ratio from golden masks
        self.zero_streak = {}  # obj_id -> consecutive frames with empty/tiny mask
        self.occluded = {}  # obj_id -> bool
        self.recovery_pending = {}  # obj_id -> bool (just left occlusion, needs re-seed)
        self.last_reseed_frame = {}  # obj_id -> frame_idx of last re-seed (debounce)

    @staticmethod
    def _bbox_aspect_ratio(bbox):
        # Width/height ratio with both sides clamped to >= 1px so the
        # division is always defined; None propagates for missing boxes.
        if bbox is None:
            return None
        w = max(float(bbox[2] - bbox[0]), 1.0)
        h = max(float(bbox[3] - bbox[1]), 1.0)
        return w / h

    def _passes_shape_guard(self, obj_id, area, bbox):
        """Keep golden masks close to the learned prompt-shape profile."""
        # Reject candidates whose area drifts outside the accepted ratio
        # band around the running median of stored golden areas.
        ref_area = self.golden_area.get(obj_id)
        if ref_area is not None and ref_area > 0:
            ratio = float(area) / float(ref_area)
            lo, hi = self.shape_area_ratio_range
            if ratio < lo or ratio > hi:
                return False

        # Same idea for aspect ratio, using relative change vs the median.
        ref_aspect = self.golden_aspect.get(obj_id)
        cur_aspect = self._bbox_aspect_ratio(bbox)
        if ref_aspect is not None and ref_aspect > 0 and cur_aspect is not None:
            rel_change = abs(cur_aspect - ref_aspect) / ref_aspect
            if rel_change > self.max_aspect_ratio_change:
                return False
        return True

    def store_if_golden(self, obj_id, mask, bbox, area, score, frame_idx):
        """Store mask snapshot if quality is high enough."""
        if score < self.min_score_to_store or mask is None or area < 1:
            return
        if not self._passes_shape_guard(obj_id, area, bbox):
            return
        if obj_id not in self.snapshots:
            self.snapshots[obj_id] = []

        # Copies taken so later in-place edits by the caller cannot mutate
        # a stored snapshot.
        entry = (score, frame_idx, mask.copy(), bbox.copy() if bbox is not None else None, area)
        snaps = self.snapshots[obj_id]
        snaps.append(entry)
        snaps.sort(key=lambda e: e[0], reverse=True)
        if len(snaps) > self.max_snapshots:
            snaps[:] = snaps[:self.max_snapshots]

        # Refresh the reference ("golden") area as the median across snapshots.
        areas = [s[4] for s in snaps]
        self.golden_area[obj_id] = float(np.median(areas))
        aspects = [self._bbox_aspect_ratio(s[3]) for s in snaps if s[3] is not None]
        if aspects:
            self.golden_aspect[obj_id] = float(np.median(aspects))

    def update_occlusion_state(self, obj_id, mask_area, frame_idx):
        """
        Track whether object is in long occlusion and detect recovery.
        Returns True if a re-seed should happen this frame.
        """
        # Without a golden reference area there is nothing to compare against.
        golden = self.golden_area.get(obj_id)
        if golden is None or golden < 1:
            return False

        area_fraction = mask_area / golden

        # Object is mostly gone?
        if area_fraction < 0.1:
            self.zero_streak[obj_id] = self.zero_streak.get(obj_id, 0) + 1
        else:
            self.zero_streak[obj_id] = 0

        was_occluded = self.occluded.get(obj_id, False)

        # Enter occlusion mode after sustained absence
        if self.zero_streak.get(obj_id, 0) >= self.occlusion_enter_frames:
            self.occluded[obj_id] = True

        # Detect recovery: was occluded, now partial mask is back
        if was_occluded and area_fraction >= self.recovery_area_ratio:
            # NOTE(review): the occluded flag is cleared even when the
            # debounce below suppresses the re-seed, so a debounced recovery
            # is not retried on the next frame — confirm this is intended.
            self.occluded[obj_id] = False
            self.zero_streak[obj_id] = 0

            # Debounce: don't re-seed too frequently
            last = self.last_reseed_frame.get(obj_id, -999)
            if frame_idx - last >= self.reseed_debounce_frames:
                self.recovery_pending[obj_id] = True
                self.last_reseed_frame[obj_id] = frame_idx
                return True
        return False

    def pop_reseed_mask(self, obj_id):
        """
        Return the best golden mask for re-seeding, or None.
        Clears the pending flag.
        """
        self.recovery_pending.pop(obj_id, None)
        snaps = self.snapshots.get(obj_id)
        if not snaps:
            return None
        return snaps[0][2]  # highest-score snapshot mask array

    def is_recovery_pending(self, obj_id):
        # True when a recovery was detected but the caller has not yet
        # consumed the golden mask via pop_reseed_mask().
        return self.recovery_pending.get(obj_id, False)

    def has_snapshots(self, obj_id):
        # True when at least one golden snapshot is stored for obj_id.
        return bool(self.snapshots.get(obj_id))
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class MultiObjectMotionTracker:
|
|
444
|
+
"""
|
|
445
|
+
Manages Kalman filter trackers for multiple objects with
|
|
446
|
+
motion-aware tracking, drift detection, and automatic prompt injection.
|
|
447
|
+
|
|
448
|
+
Supports OC-SORT enhancements:
|
|
449
|
+
- Virtual trajectory during occlusions
|
|
450
|
+
- ORU (Observation-Centric Re-Update) for drift correction
|
|
451
|
+
"""
|
|
452
|
+
def __init__(self, motion_score_threshold=0.3, use_kalman=True,
|
|
453
|
+
consecutive_low_threshold=3, area_change_threshold=0.5,
|
|
454
|
+
use_ocsort=False, ocsort_inertia=0.2, max_history_frames=1000,
|
|
455
|
+
adaptive_threshold=True, threshold_window=30, hysteresis_margin=0.05,
|
|
456
|
+
max_correction_jump_px=80.0, max_correction_area_ratio=2.5,
|
|
457
|
+
enable_appearance_memory=True, appearance_min_score=0.6,
|
|
458
|
+
appearance_max_snapshots=5, occlusion_enter_frames=5,
|
|
459
|
+
recovery_area_ratio=0.15, reseed_debounce_frames=5,
|
|
460
|
+
shape_area_ratio_range=(0.7, 1.3), max_aspect_ratio_change=0.35):
|
|
461
|
+
self.trackers = {} # obj_id -> KalmanBoxTracker
|
|
462
|
+
self.scores = {} # obj_id -> {frame_idx: score}
|
|
463
|
+
self.bboxes = {} # obj_id -> {frame_idx: bbox} for temporal consistency
|
|
464
|
+
self.areas = {} # obj_id -> {frame_idx: area}
|
|
465
|
+
self.motion_score_threshold = motion_score_threshold
|
|
466
|
+
self.use_kalman = use_kalman
|
|
467
|
+
self.kalman_available = True
|
|
468
|
+
self.consecutive_low_threshold = consecutive_low_threshold # Frames before recovery
|
|
469
|
+
self.area_change_threshold = area_change_threshold # Max allowed area change ratio
|
|
470
|
+
self.consecutive_low_count = {} # obj_id -> count of consecutive low scores
|
|
471
|
+
self.needs_correction = {} # obj_id -> bool
|
|
472
|
+
|
|
473
|
+
# OC-SORT parameters
|
|
474
|
+
self.use_ocsort = use_ocsort
|
|
475
|
+
self.ocsort_inertia = ocsort_inertia # Velocity smoothing (paper: 0.2)
|
|
476
|
+
self.max_history_frames = max_history_frames
|
|
477
|
+
self.adaptive_threshold = adaptive_threshold
|
|
478
|
+
self.threshold_window = threshold_window
|
|
479
|
+
self.hysteresis_margin = hysteresis_margin
|
|
480
|
+
self.max_correction_jump_px = max_correction_jump_px
|
|
481
|
+
self.max_correction_area_ratio = max_correction_area_ratio
|
|
482
|
+
self.low_state = {} # obj_id -> bool (hysteresis state)
|
|
483
|
+
|
|
484
|
+
# Long-term appearance memory for occlusion recovery
|
|
485
|
+
self.enable_appearance_memory = enable_appearance_memory
|
|
486
|
+
self.appearance_memory = AppearanceMemoryBank(
|
|
487
|
+
max_snapshots=appearance_max_snapshots,
|
|
488
|
+
min_score_to_store=appearance_min_score,
|
|
489
|
+
occlusion_enter_frames=occlusion_enter_frames,
|
|
490
|
+
recovery_area_ratio=recovery_area_ratio,
|
|
491
|
+
reseed_debounce_frames=reseed_debounce_frames,
|
|
492
|
+
shape_area_ratio_range=shape_area_ratio_range,
|
|
493
|
+
max_aspect_ratio_change=max_aspect_ratio_change,
|
|
494
|
+
) if enable_appearance_memory else None
|
|
495
|
+
|
|
496
|
+
try:
|
|
497
|
+
from filterpy.kalman import KalmanFilter
|
|
498
|
+
except ImportError:
|
|
499
|
+
self.kalman_available = False
|
|
500
|
+
self.use_kalman = False
|
|
501
|
+
|
|
502
|
+
def initialize_tracker(self, obj_id, bbox, frame_idx=0):
|
|
503
|
+
"""Initialize a new tracker for an object."""
|
|
504
|
+
if not self.use_kalman or not self.kalman_available:
|
|
505
|
+
return
|
|
506
|
+
if bbox is None:
|
|
507
|
+
return
|
|
508
|
+
try:
|
|
509
|
+
self.trackers[obj_id] = KalmanBoxTracker(
|
|
510
|
+
bbox,
|
|
511
|
+
delta_t=1,
|
|
512
|
+
inertia=self.ocsort_inertia
|
|
513
|
+
)
|
|
514
|
+
self.trackers[obj_id].last_observation_frame = frame_idx
|
|
515
|
+
self.scores[obj_id] = {}
|
|
516
|
+
self.bboxes[obj_id] = {}
|
|
517
|
+
self.areas[obj_id] = {}
|
|
518
|
+
self.consecutive_low_count[obj_id] = 0
|
|
519
|
+
self.needs_correction[obj_id] = False
|
|
520
|
+
self.low_state[obj_id] = False
|
|
521
|
+
except Exception as e:
|
|
522
|
+
logger.debug("Failed to initialize tracker for obj %s: %s", obj_id, e)
|
|
523
|
+
|
|
524
|
+
def predict(self, obj_id):
|
|
525
|
+
"""Return the current Kalman state estimate without advancing it."""
|
|
526
|
+
if obj_id not in self.trackers:
|
|
527
|
+
return None
|
|
528
|
+
try:
|
|
529
|
+
return self.trackers[obj_id].get_state()
|
|
530
|
+
except Exception as e:
|
|
531
|
+
logger.debug("Failed to get state for obj %s: %s", obj_id, e)
|
|
532
|
+
return None
|
|
533
|
+
|
|
534
|
+
def predict_and_advance(self, obj_id, frame_idx=None):
|
|
535
|
+
"""Advance the Kalman state by one step and return the predicted bbox.
|
|
536
|
+
|
|
537
|
+
When OC-SORT is enabled and a frame index is provided, the tracker
|
|
538
|
+
uses a virtual trajectory so occlusions do not zero out velocity.
|
|
539
|
+
"""
|
|
540
|
+
if obj_id not in self.trackers:
|
|
541
|
+
return None
|
|
542
|
+
try:
|
|
543
|
+
if self.use_ocsort and frame_idx is not None:
|
|
544
|
+
return self.trackers[obj_id].predict_with_virtual_trajectory(frame_idx)
|
|
545
|
+
else:
|
|
546
|
+
return self.trackers[obj_id].predict()
|
|
547
|
+
except Exception as e:
|
|
548
|
+
logger.debug("Failed to predict tracker for obj %s: %s", obj_id, e)
|
|
549
|
+
return None
|
|
550
|
+
|
|
551
|
+
def get_predicted_bbox_for_correction(self, obj_id):
|
|
552
|
+
"""Return the Kalman bbox as an [x1, y1, x2, y2] box prompt."""
|
|
553
|
+
if obj_id not in self.trackers:
|
|
554
|
+
return None
|
|
555
|
+
try:
|
|
556
|
+
bbox = self.trackers[obj_id].get_state()
|
|
557
|
+
if bbox is None:
|
|
558
|
+
return None
|
|
559
|
+
return bbox.tolist()
|
|
560
|
+
except Exception as e:
|
|
561
|
+
logger.debug("Failed to get correction bbox for obj %s: %s", obj_id, e)
|
|
562
|
+
return None
|
|
563
|
+
|
|
564
|
+
def get_frame_score(self, obj_id, frame_idx):
|
|
565
|
+
if obj_id not in self.scores:
|
|
566
|
+
return None
|
|
567
|
+
return self.scores[obj_id].get(frame_idx)
|
|
568
|
+
|
|
569
|
+
def get_effective_threshold(self, obj_id):
|
|
570
|
+
base = self.motion_score_threshold
|
|
571
|
+
if not self.adaptive_threshold or obj_id not in self.scores:
|
|
572
|
+
return base
|
|
573
|
+
frames = sorted(self.scores[obj_id].keys())[-self.threshold_window:]
|
|
574
|
+
if len(frames) < 5:
|
|
575
|
+
return base
|
|
576
|
+
vals = np.array([self.scores[obj_id][f] for f in frames], dtype=np.float32)
|
|
577
|
+
mean = float(vals.mean())
|
|
578
|
+
std = float(vals.std())
|
|
579
|
+
adaptive = mean - 0.5 * std
|
|
580
|
+
lower = max(0.05, base - 0.15)
|
|
581
|
+
upper = min(0.95, base + 0.15)
|
|
582
|
+
return float(np.clip(adaptive, lower, upper))
|
|
583
|
+
|
|
584
|
+
def is_correction_bbox_sane(self, obj_id, pred_bbox):
|
|
585
|
+
if pred_bbox is None or obj_id not in self.bboxes or not self.bboxes[obj_id]:
|
|
586
|
+
return False
|
|
587
|
+
last_frame = max(self.bboxes[obj_id].keys())
|
|
588
|
+
last_bbox = self.bboxes[obj_id][last_frame]
|
|
589
|
+
if last_bbox is None:
|
|
590
|
+
return False
|
|
591
|
+
|
|
592
|
+
pred_bbox = np.asarray(pred_bbox, dtype=np.float32)
|
|
593
|
+
last_bbox = np.asarray(last_bbox, dtype=np.float32)
|
|
594
|
+
|
|
595
|
+
pred_center = np.array(
|
|
596
|
+
[(pred_bbox[0] + pred_bbox[2]) * 0.5, (pred_bbox[1] + pred_bbox[3]) * 0.5],
|
|
597
|
+
dtype=np.float32,
|
|
598
|
+
)
|
|
599
|
+
last_center = np.array(
|
|
600
|
+
[(last_bbox[0] + last_bbox[2]) * 0.5, (last_bbox[1] + last_bbox[3]) * 0.5],
|
|
601
|
+
dtype=np.float32,
|
|
602
|
+
)
|
|
603
|
+
center_shift = float(np.linalg.norm(pred_center - last_center))
|
|
604
|
+
if center_shift > self.max_correction_jump_px:
|
|
605
|
+
return False
|
|
606
|
+
|
|
607
|
+
pred_area = max((pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1]), 1.0)
|
|
608
|
+
last_area = max((last_bbox[2] - last_bbox[0]) * (last_bbox[3] - last_bbox[1]), 1.0)
|
|
609
|
+
area_ratio = max(pred_area / last_area, last_area / pred_area)
|
|
610
|
+
return area_ratio <= self.max_correction_area_ratio
|
|
611
|
+
|
|
612
|
+
def update(self, obj_id, mask, mask_logit, frame_idx):
    """
    Update tracker with observed mask and compute quality score.

    Args:
        obj_id: Key into the per-object stores (scores/bboxes/areas/trackers).
        mask: Mask for this frame; pixels > 0 count as foreground.
        mask_logit: Raw mask logits, forwarded to compute_mask_score.
        frame_idx: Frame index used as the time key for every history dict
            and as the Kalman time index.

    Returns:
        score: Quality score for this frame
        should_use_for_memory: Whether this frame should be used for memory
    """
    actual_bbox = mask_to_bbox(mask)

    # Empty mask -> object not visible this frame: score 0, not memory-worthy.
    if actual_bbox is None:
        # Only extend an existing low-score streak; NOTE(review): an object
        # whose very first observation is empty leaves the counter unset.
        if obj_id in self.consecutive_low_count:
            self.consecutive_low_count[obj_id] += 1
        # OC-SORT: Still predict with virtual trajectory to maintain state
        if self.use_ocsort and obj_id in self.trackers:
            try:
                self.trackers[obj_id].predict_with_virtual_trajectory(frame_idx)
            except Exception as e:
                # Best-effort: a tracker failure must not abort scoring.
                logger.debug(
                    "Failed virtual trajectory for obj %s at frame %s: %s",
                    obj_id,
                    frame_idx,
                    e,
                )
        return 0.0, False

    mask_area = float(np.sum(mask > 0))

    # First observation of this object: seed every history store with a
    # perfect score (1.0) and report the frame as memory-worthy.
    if obj_id not in self.trackers:
        self.initialize_tracker(obj_id, actual_bbox, frame_idx)
        if obj_id not in self.scores:
            self.scores[obj_id] = {}
        self.scores[obj_id][frame_idx] = 1.0
        self.bboxes[obj_id] = {frame_idx: actual_bbox}
        self.areas[obj_id] = {frame_idx: mask_area}
        self.low_state[obj_id] = False
        return 1.0, True

    # OC-SORT virtual trajectory: advance Kalman state through occlusions
    # using frame_idx as the time index.
    predicted_bbox = self.predict_and_advance(obj_id, frame_idx) if self.use_kalman else None

    # obj_score / motion_iou are computed but unused here; only the combined
    # score drives the keep/drop decision below.
    score, obj_score, motion_iou = compute_mask_score(
        mask_logit, predicted_bbox, actual_bbox
    )

    # Temporal consistency: large area jumps usually indicate mask drift,
    # so penalise the score when the current area deviates from the
    # running mean over the last five frames.
    if obj_id in self.areas and self.areas[obj_id]:
        recent_frames = sorted(self.areas[obj_id].keys())[-5:]
        if recent_frames:
            avg_area = np.mean([self.areas[obj_id][f] for f in recent_frames])
            if avg_area > 0:
                area_ratio = mask_area / avg_area
                # Outside the (1 +/- area_change_threshold) band: halve the score.
                if area_ratio < (1 - self.area_change_threshold) or \
                        area_ratio > (1 + self.area_change_threshold):
                    score *= 0.5

    # OC-SORT Observation-Centric Re-Update: feed the real observation back
    # into the tracker keyed on frame_idx.
    # NOTE(review): this is gated on use_kalman while the virtual-trajectory
    # branch above is gated on use_ocsort -- confirm the asymmetry is intended.
    if self.use_kalman and obj_id in self.trackers:
        try:
            self.trackers[obj_id].update(actual_bbox, frame_idx)
        except Exception as e:
            logger.debug("Failed to update tracker for obj %s at frame %s: %s", obj_id, frame_idx, e)

    # Store data
    if obj_id not in self.scores:
        self.scores[obj_id] = {}
    self.scores[obj_id][frame_idx] = score
    self.bboxes[obj_id][frame_idx] = actual_bbox
    self.areas[obj_id][frame_idx] = mask_area
    self._prune_history(obj_id)

    # Adaptive threshold + hysteresis to reduce flicker in keep/drop decisions:
    # while in the low state, recovery requires score >= high_cut; otherwise
    # the frame is kept as long as score >= low_cut.
    threshold = self.get_effective_threshold(obj_id)
    low_cut = max(0.0, threshold - self.hysteresis_margin)
    high_cut = min(1.0, threshold + self.hysteresis_margin)
    in_low_state = self.low_state.get(obj_id, False)
    if in_low_state:
        should_use = score >= high_cut
    else:
        should_use = score >= low_cut
    self.low_state[obj_id] = not should_use

    # Track consecutive low scores for drift detection
    if not should_use:
        self.consecutive_low_count[obj_id] = self.consecutive_low_count.get(obj_id, 0) + 1
    else:
        self.consecutive_low_count[obj_id] = 0

    # Flag for correction if too many consecutive low scores
    if self.consecutive_low_count.get(obj_id, 0) >= self.consecutive_low_threshold:
        self.needs_correction[obj_id] = True
    else:
        self.needs_correction[obj_id] = False

    # Long-term appearance memory: store golden masks + detect occlusion recovery
    if self.appearance_memory is not None:
        self.appearance_memory.store_if_golden(
            obj_id, mask, actual_bbox, mask_area, score, frame_idx
        )
        self.appearance_memory.update_occlusion_state(obj_id, mask_area, frame_idx)

    return score, should_use
|
|
718
|
+
|
|
719
|
+
def _prune_history(self, obj_id):
|
|
720
|
+
"""Keep only recent history to cap memory growth."""
|
|
721
|
+
if self.max_history_frames is None or self.max_history_frames <= 0:
|
|
722
|
+
return
|
|
723
|
+
for store in (self.scores.get(obj_id), self.bboxes.get(obj_id), self.areas.get(obj_id)):
|
|
724
|
+
if not store or len(store) <= self.max_history_frames:
|
|
725
|
+
continue
|
|
726
|
+
trim_count = len(store) - self.max_history_frames
|
|
727
|
+
for frame_idx in sorted(store.keys())[:trim_count]:
|
|
728
|
+
del store[frame_idx]
|
|
729
|
+
|
|
730
|
+
def check_needs_correction(self, obj_id):
    """Report whether drift detection has flagged this object for re-prompting."""
    flagged = self.needs_correction.get(obj_id, False)
    return flagged
|
|
733
|
+
|
|
734
|
+
def reset_correction_flag(self, obj_id):
    """Clear all drift bookkeeping once a correction has been applied."""
    # Leave the hysteresis state, the low-score streak, and the correction
    # flag all back at their healthy defaults.
    self.low_state[obj_id] = False
    self.consecutive_low_count[obj_id] = 0
    self.needs_correction[obj_id] = False
|
|
739
|
+
|
|
740
|
+
def get_low_score_frames(self, obj_id, threshold=None):
    """Return the frame indices whose quality score falls below *threshold*.

    Defaults to ``motion_score_threshold`` when no threshold is given;
    returns an empty list for an unknown object.
    """
    cutoff = self.motion_score_threshold if threshold is None else threshold
    per_frame = self.scores.get(obj_id)
    if per_frame is None:
        return []
    return [frame for frame, value in per_frame.items() if value < cutoff]
|
|
747
|
+
|
|
748
|
+
def get_recent_scores(self, obj_id, n_frames=10):
    """Return (frame_idx, score) pairs for the most recent *n_frames* frames."""
    per_frame = self.scores.get(obj_id)
    if per_frame is None:
        return []
    recent = sorted(per_frame)[-n_frames:]
    return [(frame, per_frame[frame]) for frame in recent]
|
|
754
|
+
|
|
755
|
+
def get_best_memory_frames(self, obj_id, n_frames=6):
    """
    Return up to *n_frames* (frame_idx, score) pairs, best score first,
    for seeding the memory bank. Unknown objects yield an empty list.
    """
    per_frame = self.scores.get(obj_id)
    if per_frame is None:
        return []
    # Stable descending sort by score, then keep the top slice.
    ranked = sorted(per_frame.items(), key=lambda item: item[1], reverse=True)
    return ranked[:n_frames]
|
|
764
|
+
|