singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from typing import Optional, List, Dict, Any
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
# Module-level logger; AnnotationManager uses it to report annotation files it cannot load.
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AnnotationManager:
    """Load, query, and persist clip annotations stored in a single JSON file.

    The file holds a dict with two top-level lists:

    * ``"classes"`` -- behavior class names, in insertion order.
    * ``"clips"`` -- one dict per clip with ``"id"``, ``"label"`` (primary
      label, ``""`` when unlabeled) and ``"labels"`` (all labels), plus the
      optional keys ``"meta"``, ``"frame_labels"``, ``"spatial_mask"``,
      ``"spatial_bbox"`` and ``"spatial_bbox_frames"``.

    Clip ids are normalized with ``_normalize_clip_id`` before comparison.
    Setters require an exact (normalized) id match and silently do nothing
    otherwise; getters fall back to an extension-insensitive match when no
    clip matches exactly. Mutating methods call ``save()`` after a change
    (``add_clip`` and ``set_frame_labels`` can defer it via ``_defer_save``).
    """

    def __init__(self, annotation_file: str):
        """Bind to ``annotation_file`` and load it (or start empty if absent/unreadable)."""
        self.annotation_file = annotation_file
        self.data = self._load_or_create()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load_or_create(self) -> dict:
        """Return the parsed annotation dict, or a fresh empty one.

        A missing, unreadable or malformed file yields
        ``{"clips": [], "classes": []}``. Nothing is written to disk here;
        the file is only (re)created on the next ``save()``.
        """
        if os.path.exists(self.annotation_file):
            try:
                # json.dump (ensure_ascii=True) writes pure ASCII, so reading as
                # UTF-8 handles every file this class produced, on any locale.
                with open(self.annotation_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except Exception as e:
                logger.warning("Error loading annotations: %s, creating new file", e)
            else:
                if isinstance(data, dict):
                    # Tolerate files missing (or holding junk in) a top-level key
                    # instead of failing later with KeyError/TypeError.
                    for key in ("clips", "classes"):
                        if not isinstance(data.get(key), list):
                            data[key] = []
                    return data
                logger.warning(
                    "Annotation file %s does not contain a JSON object, creating new file",
                    self.annotation_file,
                )

        return {
            "clips": [],
            "classes": []
        }

    def _normalize_clip_id(self, clip_id: str) -> str:
        """Normalize clip IDs to match labeling list paths.

        Converts backslashes to forward slashes, drops a leading ``./`` and
        then at most one of the prefixes ``../clips/``, ``clips/`` or
        ``data/clips/``. ``None``/empty input yields ``""``.
        """
        normalized = (clip_id or "").replace('\\', '/')
        if normalized.startswith("./"):
            normalized = normalized[2:]
        for prefix in ("../clips/", "clips/", "data/clips/"):
            if normalized.startswith(prefix):
                normalized = normalized[len(prefix):]
                break
        return normalized

    def save(self):
        """Save annotations to file (with safe write).

        The data is written to ``<annotation_file>.tmp`` and then moved over
        the target with ``os.replace``, so a failed write never truncates the
        existing file.

        Raises:
            Whatever the write raises (e.g. ``OSError``, or ``TypeError`` for
            data that is not JSON-serializable); the temp file is removed first.
        """
        directory = os.path.dirname(self.annotation_file)
        # A bare filename has no directory part, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)

        temp_file = self.annotation_file + '.tmp'
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, indent=2)
            os.replace(temp_file, self.annotation_file)
        except Exception:
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise

    def reload(self):
        """Discard in-memory state and re-read the annotation file."""
        self.data = self._load_or_create()

    # ------------------------------------------------------------------
    # Clip lookup helpers
    # ------------------------------------------------------------------

    def _find_clip_exact(self, clip_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored clip whose normalized id equals ``clip_id``'s, else None."""
        target = self._normalize_clip_id(clip_id)
        for clip in self.data["clips"]:
            if self._normalize_clip_id(clip.get("id")) == target:
                return clip
        return None

    def _find_clip(self, clip_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored clip for ``clip_id``, tolerating extension mismatches.

        An exact (normalized) id match anywhere in the list wins. Otherwise
        the first clip whose id matches ignoring the file extension is
        returned, or None if there is none.
        """
        target = self._normalize_clip_id(clip_id)
        target_base = os.path.splitext(target)[0]
        fallback = None
        for clip in self.data["clips"]:
            stored_id = self._normalize_clip_id(clip.get("id"))
            if stored_id == target:
                return clip
            if fallback is None:
                stored_base = os.path.splitext(stored_id)[0]
                if (stored_base == target_base or stored_id == target_base
                        or target == stored_base):
                    # Remember it, but keep scanning: a later exact match takes priority.
                    fallback = clip
        return fallback

    @staticmethod
    def _clamp_bbox(bbox) -> Optional[List[float]]:
        """Clamp ``[x1, y1, x2, y2]`` into [0, 1]; return None if malformed or empty-area."""
        if bbox is None or len(bbox) != 4:
            return None
        x1, y1, x2, y2 = (max(0.0, min(1.0, float(v))) for v in bbox)
        if x2 <= x1 or y2 <= y1:
            return None
        return [x1, y1, x2, y2]

    # ------------------------------------------------------------------
    # Classes
    # ------------------------------------------------------------------

    def add_class(self, class_name: str):
        """Append ``class_name`` to the class list (and save) if it is not already present."""
        if class_name not in self.data["classes"]:
            self.data["classes"].append(class_name)
            self.save()

    def remove_class(self, class_name: str):
        """Remove ``class_name`` from the class list (and save); clips are not modified."""
        if class_name in self.data["classes"]:
            self.data["classes"].remove(class_name)
            self.save()

    def rename_class(self, old_name: str, new_name: str) -> bool:
        """Rename a behavior class and update all associated clips.

        If ``new_name`` already exists, the two classes are merged. Each
        clip's ``label``, ``labels`` and per-frame ``frame_labels`` are
        rewritten; duplicate entries created by a merge are collapsed.

        Returns:
            False if ``old_name`` is not a known class, otherwise True.
        """
        classes = self.data["classes"]
        if old_name not in classes:
            return False
        if old_name == new_name:
            # No-op; the merge branch below would otherwise delete the class.
            return True

        if new_name not in classes:
            classes[classes.index(old_name)] = new_name
        else:
            # New name already exists, merge
            classes.remove(old_name)

        def _rename(value):
            return new_name if value == old_name else value

        for clip in self.data["clips"]:
            if clip.get("label") == old_name:
                clip["label"] = new_name
            labels = clip.get("labels")
            if isinstance(labels, list):
                # dict.fromkeys de-duplicates while preserving order, so merging
                # ["a", "b"] with a -> b gives ["b"] rather than ["b", "b"].
                clip["labels"] = list(dict.fromkeys(_rename(lbl) for lbl in labels))
                if clip["labels"]:
                    clip["label"] = clip["labels"][0]
            frame_labels = clip.get("frame_labels")
            if isinstance(frame_labels, list):
                clip["frame_labels"] = [_rename(fl) for fl in frame_labels]

        self.save()
        return True

    def get_classes(self) -> List[str]:
        """Return a copy of the class-name list."""
        return self.data["classes"].copy()

    # ------------------------------------------------------------------
    # Clip labels
    # ------------------------------------------------------------------

    def add_clip(self, clip_id: str, label, meta: Optional[Dict[str, Any]] = None,
                 _defer_save: bool = False) -> str:
        """Add or update a clip annotation. label can be str or list of str.

        Returns the clip id that was updated or created (for callers that need to set_frame_labels).
        Set _defer_save=True when adding many clips in a loop, then call save() once at the end.
        Ensures every label in labels_list is in the classes list so training sees no stray labels.

        Resolution order:
        1. A clip with the same normalized id is updated in place.
        2. For a labeled, non-segment inference add (``meta`` has
           ``source_video``/``source_frame`` and no truthy
           ``added_from_inference_segment``), the first unlabeled clip with the
           same ``source_video`` and ``sub_start_frame == source_frame`` is
           updated instead.
        3. Otherwise a new clip is appended.

        The label list and meta dict are copied, so later changes by the
        caller do not leak into the stored annotations.
        """
        clip_id_normalized = self._normalize_clip_id(clip_id)
        if isinstance(label, list):
            labels_list = list(label)
        elif label:
            labels_list = [label]
        else:
            labels_list = []
        primary = labels_list[0] if labels_list else ""

        for lbl in labels_list:
            if lbl and lbl not in self.data["classes"]:
                self.data["classes"].append(lbl)

        for clip in self.data["clips"]:
            if self._normalize_clip_id(clip["id"]) == clip_id_normalized:
                clip["id"] = clip_id_normalized
                clip["label"] = primary
                clip["labels"] = labels_list
                if meta:
                    clip.setdefault("meta", {})
                    clip["meta"].update(meta)
                if not _defer_save:
                    self.save()
                return clip_id_normalized

        # No id match: if adding from inference (single-clip, not segment), try to update an
        # existing unlabeled bulk-extracted clip for the same source video + start frame.
        if meta and primary and isinstance(meta, dict) and not meta.get("added_from_inference_segment"):
            src_video = meta.get("source_video")
            src_frame = meta.get("source_frame")
            if src_video is not None and src_frame is not None:
                for clip in self.data["clips"]:
                    if clip.get("label"):
                        continue
                    cmeta = clip.get("meta") or {}
                    if cmeta.get("source_video") != src_video:
                        continue
                    if cmeta.get("sub_start_frame") is not None and cmeta.get("sub_start_frame") == src_frame:
                        clip["label"] = primary
                        clip["labels"] = labels_list
                        clip.setdefault("meta", {})
                        clip["meta"].update(meta)
                        if not _defer_save:
                            self.save()
                        return self._normalize_clip_id(clip["id"])

        new_clip = {
            "id": clip_id_normalized,
            "label": primary,
            "labels": labels_list,
            "meta": dict(meta) if meta else {}
        }
        self.data["clips"].append(new_clip)
        if not _defer_save:
            self.save()
        return clip_id_normalized

    def get_clip_labels(self, clip_id: str) -> List[str]:
        """Get all labels for a clip (multi-label aware). Falls back to single label.

        Returns an empty list when the clip is unknown or unlabeled.
        """
        clip = self._find_clip(clip_id)
        if clip is None:
            return []
        labels = clip.get("labels")
        if labels and isinstance(labels, list):
            return list(labels)
        lbl = clip.get("label")
        return [lbl] if lbl else []

    def get_clip_label(self, clip_id: str) -> Optional[str]:
        """Return the clip's primary label, or None if the clip is unknown.

        Handles extension mismatches between clip_id and the stored
        annotation. A known but unlabeled clip returns its stored label,
        which is ``""`` for clips created by ``add_clip`` without a label.
        """
        clip = self._find_clip(clip_id)
        return None if clip is None else clip.get("label")

    # ------------------------------------------------------------------
    # Spatial annotations
    # ------------------------------------------------------------------

    def set_spatial_mask(self, clip_id: str, patch_indices: List[int]):
        """Store ``patch_indices`` (sorted) as the clip's spatial mask and save."""
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip["spatial_mask"] = sorted(patch_indices)
            self.save()

    def clear_spatial_mask(self, clip_id: str):
        """Remove the clip's spatial mask, if any, and save."""
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip.pop("spatial_mask", None)
            self.save()

    def get_spatial_mask(self, clip_id: str) -> Optional[List[int]]:
        """Return the clip's stored patch-index mask, or None."""
        clip = self._find_clip(clip_id)
        return None if clip is None else clip.get("spatial_mask")

    def set_spatial_bbox(self, clip_id: str, bbox_norm: List[float]):
        """Set spatial bbox [x1,y1,x2,y2] normalized to [0,1] for a clip.

        Coordinates are clamped to [0, 1]. A bbox that is not 4 values long,
        or has zero/negative area after clamping, is ignored.
        """
        cleaned = self._clamp_bbox(bbox_norm)
        if cleaned is None:
            return
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip["spatial_bbox"] = cleaned
            self.save()

    def clear_spatial_bbox(self, clip_id: str):
        """Remove the clip's single spatial bbox, if any, and save."""
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip.pop("spatial_bbox", None)
            self.save()

    def get_spatial_bbox(self, clip_id: str) -> Optional[List[float]]:
        """Return the clip's ``[x1, y1, x2, y2]`` bbox, or None."""
        clip = self._find_clip(clip_id)
        return None if clip is None else clip.get("spatial_bbox")

    def set_spatial_bbox_frames(self, clip_id: str, frame_bboxes: List):
        """Set per-frame bboxes for a clip.

        Args:
            frame_bboxes: list of length T, each element is [x1,y1,x2,y2] or None.
                Each bbox is clamped to [0, 1]; malformed or empty-area entries
                become None. The first valid bbox is also written to the legacy
                ``spatial_bbox`` field. An empty/None list is ignored.
        """
        if not frame_bboxes:
            return

        cleaned = [self._clamp_bbox(b) for b in frame_bboxes]
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip["spatial_bbox_frames"] = cleaned
            # Also update legacy spatial_bbox to first valid frame for backward compat
            first_valid = next((b for b in cleaned if b is not None), None)
            if first_valid is not None:
                clip["spatial_bbox"] = first_valid
            self.save()

    def get_spatial_bbox_frames(self, clip_id: str) -> Optional[List]:
        """Get per-frame bboxes for a clip, or None if not set.

        Returns list of [x1,y1,x2,y2] or None per frame.
        """
        clip = self._find_clip(clip_id)
        return None if clip is None else clip.get("spatial_bbox_frames")

    def clear_spatial_bbox_frames(self, clip_id: str):
        """Remove the clip's per-frame bboxes, if any, and save."""
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip.pop("spatial_bbox_frames", None)
            self.save()

    # ------------------------------------------------------------------
    # Per-frame labels
    # ------------------------------------------------------------------

    def set_frame_labels(self, clip_id: str, frame_labels: List[Optional[str]], _defer_save: bool = False):
        """Set per-frame behavior labels for a clip.

        Args:
            frame_labels: list of length T, each element is a class name or None.
                A copy is stored.
            _defer_save: skip the immediate save (call save() later).
        """
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip["frame_labels"] = list(frame_labels)
            if not _defer_save:
                self.save()

    def get_frame_labels(self, clip_id: str) -> Optional[List[Optional[str]]]:
        """Get per-frame behavior labels for a clip, or None if not set."""
        clip = self._find_clip(clip_id)
        return None if clip is None else clip.get("frame_labels")

    def clear_frame_labels(self, clip_id: str):
        """Remove the clip's per-frame labels, if any, and save."""
        clip = self._find_clip_exact(clip_id)
        if clip is not None:
            clip.pop("frame_labels", None)
            self.save()

    # ------------------------------------------------------------------
    # Bulk queries
    # ------------------------------------------------------------------

    def get_all_clips(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the clip list (the clip dicts themselves are shared)."""
        return self.data["clips"].copy()

    def get_labeled_clips(self) -> List[Dict[str, Any]]:
        """Return the stored clip dicts that have a truthy primary label."""
        return [c for c in self.data["clips"] if c.get("label")]

    def get_unlabeled_clips(self, all_clip_paths: List[str]) -> List[str]:
        """Get list of clip paths that don't have labels.

        Paths are compared after the same normalization used everywhere else.
        A path counts as labeled if a labeled clip shares its id or its id
        without extension. Returned entries are the caller's original strings.
        """
        labeled_ids = {
            self._normalize_clip_id(c.get("id"))
            for c in self.data["clips"] if c.get("label")
        }
        labeled_bases = {os.path.splitext(cid)[0] for cid in labeled_ids}

        unlabeled = []
        for cp in all_clip_paths:
            cp_normalized = self._normalize_clip_id(cp)
            if cp_normalized in labeled_ids:
                continue
            if os.path.splitext(cp_normalized)[0] not in labeled_bases:
                unlabeled.append(cp)

        return unlabeled

    def remove_clip(self, clip_id: str):
        """Remove every clip whose normalized id matches ``clip_id`` and save."""
        clip_id_normalized = self._normalize_clip_id(clip_id)
        self.data["clips"] = [
            c for c in self.data["clips"]
            if self._normalize_clip_id(c["id"]) != clip_id_normalized
        ]
        self.save()

    def get_clip_count_by_label(self) -> Dict[str, int]:
        """Get count of clips per label (counts each label in multi-label clips).

        Clips without labels are counted under ``"unlabeled"``.
        """
        counts: Dict[str, int] = {}
        for clip in self.data["clips"]:
            labels = clip.get("labels")
            if isinstance(labels, list) and labels:
                for lbl in labels:
                    counts[lbl] = counts.get(lbl, 0) + 1
            else:
                # add_clip stores "" for unlabeled clips; bucket those under
                # "unlabeled" instead of an empty-string key.
                label = clip.get("label") or "unlabeled"
                counts[label] = counts.get(label, 0) + 1
        return counts

    def get_multilabel_stats(self) -> dict:
        """Return per-label exclusive/multi-class counts and combo frequencies.

        Returns dict with:
            exclusive: {label: count} — clips where this is the only label
            shared: {label: count} — clips where this label co-occurs with others
            combos: {(sorted tuple of labels): count}

        Clips without labels are counted as exclusive ``"unlabeled"``.
        """
        exclusive: Dict[str, int] = {}
        shared: Dict[str, int] = {}
        combos: Dict[tuple, int] = {}
        for clip in self.data["clips"]:
            labels = clip.get("labels")
            if isinstance(labels, list) and labels:
                lbl_list = list(labels)
            else:
                lbl_list = [clip.get("label") or "unlabeled"]
            if len(lbl_list) == 1:
                exclusive[lbl_list[0]] = exclusive.get(lbl_list[0], 0) + 1
            else:
                combo_key = tuple(sorted(lbl_list))
                combos[combo_key] = combos.get(combo_key, 0) + 1
                for lbl in lbl_list:
                    shared[lbl] = shared.get(lbl, 0) + 1
        return {"exclusive": exclusive, "shared": shared, "combos": combos}

    def clear_all_clips(self):
        """Delete every clip annotation (classes are kept) and save."""
        self.data["clips"] = []
        self.save()
|
|
420
|
+
|