singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,418 @@
1
+ from PyQt6.QtWidgets import (
2
+ QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QListWidget,
3
+ QListWidgetItem, QLabel, QInputDialog, QMessageBox, QCheckBox, QFileDialog
4
+ )
5
+ from PyQt6.QtCore import Qt
6
+ import json
7
+ import logging
8
+ import os
9
+ import copy
10
+ import yaml
11
+
12
# Module-level logger; _load_profiles reports unreadable profile files through it.
logger = logging.getLogger(__name__)
13
+
14
class TrainingProfileDialog(QDialog):
    """Dialog to manage training profiles and select them for batch training.

    Profiles are persisted in ``profiles_file`` as a JSON object mapping a
    profile name to its config dict. The dialog can create, update, duplicate,
    rename, delete, import and load profiles, and lets the user check several
    of them to be run as a batch (see ``get_selected_profiles_for_batch``).

    ``parent`` is used as the source/target of settings: this class calls its
    ``get_training_config()`` and ``apply_training_config(config)`` methods and,
    if present, reads its ``config`` mapping to choose a start directory for
    imports.
    """

    def __init__(self, parent=None, profiles_file="training_profiles.json"):
        super().__init__(parent)
        self.setWindowTitle("Training Profiles")
        self.resize(450, 600)
        self.parent_widget = parent
        self.profiles_file = profiles_file
        self.profiles = self._load_profiles()
        self._setup_ui()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _load_profiles(self):
        """Load profiles from ``self.profiles_file``.

        Returns:
            dict: the parsed profiles, or ``{}`` when the file does not exist,
            cannot be read/parsed, or does not contain a JSON object.
        """
        if not os.path.exists(self.profiles_file):
            return {}
        try:
            with open(self.profiles_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            # Best effort: an unreadable file behaves like an empty one.
            logger.error("Failed to load training profiles from %s: %s", self.profiles_file, e)
            return {}
        # A non-object top level (e.g. a JSON list) would crash _refresh_list,
        # which iterates the keys of self.profiles.
        if not isinstance(data, dict):
            logger.error(
                "Training profiles file %s must contain a JSON object, got %s",
                self.profiles_file,
                type(data).__name__,
            )
            return {}
        return data

    def _save_profiles(self):
        """Write ``self.profiles`` to ``self.profiles_file``.

        The JSON is written to a temporary sibling file first and then moved
        over the real file with ``os.replace``. A failure during serialization
        (e.g. a config value that is not JSON-serializable) therefore leaves
        the previous file intact instead of a truncated one.

        Returns:
            bool: True on success; False after warning the user on failure.
        """
        tmp_path = f"{self.profiles_file}.tmp"
        try:
            parent_dir = os.path.dirname(self.profiles_file)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.profiles, f, indent=2)
            os.replace(tmp_path, self.profiles_file)
        except Exception as e:
            logger.error("Failed to save training profiles to %s: %s", self.profiles_file, e)
            # Best-effort cleanup of the partial temp file; the original
            # profiles file has not been touched at this point.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass
            QMessageBox.warning(self, "Error", f"Failed to save profiles: {e}")
            return False
        return True

    def reload_profiles(self, profiles_file=None):
        """Reload profiles from disk and refresh the visible list.

        Args:
            profiles_file: optional new path to read from (and save to later).
        """
        if profiles_file:
            self.profiles_file = profiles_file
        self.profiles = self._load_profiles()
        self._refresh_list()

    def showEvent(self, event):
        # Re-read from disk every time the dialog is shown so edits made
        # elsewhere (or a switched experiment) are picked up.
        self.reload_profiles()
        super().showEvent(event)

    # ------------------------------------------------------------------
    # UI construction and list state
    # ------------------------------------------------------------------

    def _setup_ui(self):
        """Build the list widget and the action buttons."""
        layout = QVBoxLayout()

        layout.addWidget(QLabel("<b>Manage Training Profiles</b>"))
        layout.addWidget(QLabel("Check multiple profiles to run them as a batch sequence."))

        self.list_widget = QListWidget()
        self.list_widget.itemClicked.connect(self._on_item_clicked)
        layout.addWidget(self.list_widget)

        btn_layout = QVBoxLayout()

        # --- Group 1: Create/Update ---
        create_layout = QHBoxLayout()

        save_btn = QPushButton("Save New Profile")
        save_btn.clicked.connect(self._save_new)
        save_btn.setStyleSheet("font-weight: bold;")
        save_btn.setToolTip("Save current UI settings as a new profile (prompts for name)")
        create_layout.addWidget(save_btn)

        self.update_btn = QPushButton("Update Selected")
        self.update_btn.clicked.connect(self._update_selected)
        self.update_btn.setEnabled(False)
        self.update_btn.setToolTip("Overwrite the selected profile with current UI settings")
        create_layout.addWidget(self.update_btn)

        self.import_btn = QPushButton("Import From Experiment...")
        self.import_btn.clicked.connect(self._import_profiles)
        self.import_btn.setToolTip("Import training profiles from another experiment's config.yaml or training_profiles.json")
        create_layout.addWidget(self.import_btn)

        btn_layout.addLayout(create_layout)

        # --- Group 2: Manage Selected ---
        manage_layout = QHBoxLayout()

        self.duplicate_btn = QPushButton("Duplicate")
        self.duplicate_btn.clicked.connect(self._duplicate_selected)
        self.duplicate_btn.setEnabled(False)
        self.duplicate_btn.setToolTip("Create a copy of the selected profile")
        manage_layout.addWidget(self.duplicate_btn)

        self.rename_btn = QPushButton("Rename")
        self.rename_btn.clicked.connect(self._rename_selected)
        self.rename_btn.setEnabled(False)
        self.rename_btn.setToolTip("Rename the selected profile")
        manage_layout.addWidget(self.rename_btn)

        self.delete_btn = QPushButton("Delete")
        self.delete_btn.clicked.connect(self._delete_selected)
        self.delete_btn.setEnabled(False)
        self.delete_btn.setStyleSheet("color: red;")
        manage_layout.addWidget(self.delete_btn)

        btn_layout.addLayout(manage_layout)

        # --- Group 3: Load to UI ---
        self.load_btn = QPushButton("Load Profile")
        self.load_btn.clicked.connect(self._load_selected)
        self.load_btn.setEnabled(False)
        self.load_btn.setToolTip("Apply the selected profile settings to the main Training tab")
        btn_layout.addWidget(self.load_btn)

        layout.addLayout(btn_layout)

        layout.addSpacing(10)

        close_btn = QPushButton("Close")
        close_btn.clicked.connect(self.accept)
        layout.addWidget(close_btn)

        self.setLayout(layout)

        # Keyboard navigation and programmatic selection change the current
        # item without emitting itemClicked; keep the buttons in sync. Connected
        # only after all buttons exist so the slot can always reach them.
        self.list_widget.currentItemChanged.connect(self._on_current_item_changed)

        # Refresh list now that UI elements are created
        self._refresh_list()

    def _set_selection_buttons_enabled(self, enabled):
        """Enable/disable every button that acts on the selected profile."""
        for btn in (self.load_btn, self.delete_btn, self.update_btn,
                    self.duplicate_btn, self.rename_btn):
            btn.setEnabled(enabled)

    def _refresh_list(self):
        """Rebuild the list widget from self.profiles.

        Keeps the current item and the check marks of entries whose names
        still exist.
        """
        current = self.list_widget.currentItem()
        current_item_text = current.text() if current else None
        existing_items = (self.list_widget.item(i) for i in range(self.list_widget.count()))
        checked_names = {
            item.text() for item in existing_items
            if item.checkState() == Qt.CheckState.Checked
        }

        self.list_widget.clear()
        for name in sorted(self.profiles):
            item = QListWidgetItem(name)
            item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable)
            item.setCheckState(Qt.CheckState.Checked if name in checked_names else Qt.CheckState.Unchecked)
            self.list_widget.addItem(item)
            if name == current_item_text:
                self.list_widget.setCurrentItem(item)
                item.setSelected(True)

        self._set_selection_buttons_enabled(self.list_widget.currentItem() is not None)

    def _on_item_clicked(self, item):
        """Slot for itemClicked: a clicked item is always a valid selection."""
        self._set_selection_buttons_enabled(True)

    def _on_current_item_changed(self, current, _previous):
        """Slot for currentItemChanged: enable buttons only while an item is current."""
        self._set_selection_buttons_enabled(current is not None)

    def _select_profile(self, name):
        """Make the list entry called *name* current.

        Returns:
            The matching QListWidgetItem, or None if no entry has that name.
        """
        items = self.list_widget.findItems(name, Qt.MatchFlag.MatchExactly)
        if not items:
            return None
        self.list_widget.setCurrentItem(items[0])
        self._on_item_clicked(items[0])
        return items[0]

    # ------------------------------------------------------------------
    # Create / update
    # ------------------------------------------------------------------

    def _save_new(self):
        """Save current main UI settings as a new profile."""
        if not self.parent_widget:
            return

        name, ok = QInputDialog.getText(self, "Save New Profile", "Profile Name:")
        name = name.strip() if ok else ""
        if not name:
            return

        if name in self.profiles:
            QMessageBox.warning(self, "Error", f"Profile '{name}' already exists.\nUse 'Update Selected' or choose a different name.")
            return

        self.profiles[name] = self.parent_widget.get_training_config()
        saved = self._save_profiles()
        self._refresh_list()
        self._select_profile(name)

        # Only confirm when the profile actually reached disk.
        if saved:
            QMessageBox.information(self, "Saved", f"Profile '{name}' saved.")

    def _update_selected(self):
        """Update the selected profile with current UI settings."""
        item = self.list_widget.currentItem()
        if not item or not self.parent_widget:
            return
        name = item.text()

        reply = QMessageBox.question(self, "Update Profile", f"Overwrite profile '{name}' with current settings from the UI?")
        if reply != QMessageBox.StandardButton.Yes:
            return
        self.profiles[name] = self.parent_widget.get_training_config()
        if self._save_profiles():
            QMessageBox.information(self, "Updated", f"Profile '{name}' updated.")

    # ------------------------------------------------------------------
    # Import from another experiment
    # ------------------------------------------------------------------

    def _default_import_dir(self):
        """Pick the directory the import file dialog starts in.

        Prefers the first non-empty of ``experiments_dir``, ``experiment_path``,
        ``config_path`` from the parent's ``config`` (using the containing
        directory when the value is a file), then the directory of the
        current profiles file, then the working directory.
        """
        if self.parent_widget and hasattr(self.parent_widget, "config"):
            cfg = self.parent_widget.config or {}
            for key in ("experiments_dir", "experiment_path", "config_path"):
                path = cfg.get(key)
                if path:
                    if os.path.isfile(path):
                        return os.path.dirname(path)
                    return path
        # dirname() of a bare filename is "", which is not a usable start dir.
        profiles_dir = os.path.dirname(self.profiles_file) if self.profiles_file else ""
        return profiles_dir or os.getcwd()

    @staticmethod
    def _profiles_path_from_config(config_path):
        """Return the ``training_profiles_path`` declared in a YAML config.

        Relative values are resolved against the config's directory. Returns
        None when the config cannot be read, is not a mapping, or does not
        declare a string ``training_profiles_path``.
        """
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
            logger.warning("Could not read experiment config %s: %s", config_path, e)
            return None
        if not isinstance(cfg, dict):
            return None
        profiles_path = cfg.get("training_profiles_path")
        if not profiles_path or not isinstance(profiles_path, str):
            return None
        if not os.path.isabs(profiles_path):
            profiles_path = os.path.join(os.path.dirname(config_path), profiles_path)
        return os.path.abspath(profiles_path)

    def _resolve_import_profiles_path(self, selected_path):
        """Resolve a user-chosen experiment/config file to training_profiles.json.

        Accepts a directory (looks for training_profiles.json inside it), a
        training_profiles.json file, a YAML experiment config (uses its
        ``training_profiles_path`` or a sibling training_profiles.json), or
        any other existing file (returned as-is).

        Returns:
            str | None: an absolute path to an existing file, or None.
        """
        if not selected_path:
            return None
        selected_path = os.path.abspath(selected_path)
        if os.path.isdir(selected_path):
            candidate = os.path.join(selected_path, "training_profiles.json")
            return candidate if os.path.exists(candidate) else None
        if os.path.basename(selected_path) == "training_profiles.json":
            return selected_path
        if selected_path.lower().endswith((".yaml", ".yml")):
            profiles_path = self._profiles_path_from_config(selected_path)
            if profiles_path and os.path.exists(profiles_path):
                return profiles_path
            # Fall back to the conventional location next to the config, even
            # when the config itself could not be parsed.
            candidate = os.path.join(os.path.dirname(selected_path), "training_profiles.json")
            return candidate if os.path.exists(candidate) else None
        return selected_path if os.path.exists(selected_path) else None

    def _load_external_profiles(self, profiles_path):
        """Load and validate external training profiles.

        Returns:
            dict: name -> config for every entry whose value is a dict.

        Raises:
            ValueError: if the file is not a JSON object or has no valid entries.
            OSError / json.JSONDecodeError: if the file cannot be read/parsed.
        """
        with open(profiles_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        if not isinstance(loaded, dict):
            raise ValueError("profiles file must contain a JSON object")
        valid = {str(name): cfg for name, cfg in loaded.items() if isinstance(cfg, dict)}
        if not valid:
            raise ValueError("no valid profile entries found")
        return valid

    def _ask_duplicate_mode(self, duplicate_names):
        """Ask how to handle imported names that already exist here.

        Returns:
            "overwrite", "rename", or None if the user cancelled.
        """
        msg = QMessageBox(self)
        msg.setIcon(QMessageBox.Icon.Question)
        msg.setWindowTitle("Duplicate Profile Names")
        preview = ", ".join(duplicate_names[:6])
        if len(duplicate_names) > 6:
            preview += ", ..."
        msg.setText(
            "Some imported profile names already exist in this experiment.\n\n"
            f"Duplicates: {preview}"
        )
        overwrite_btn = msg.addButton("Overwrite Duplicates", QMessageBox.ButtonRole.AcceptRole)
        rename_btn = msg.addButton("Keep Both (rename imported)", QMessageBox.ButtonRole.ActionRole)
        msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
        msg.exec()
        clicked = msg.clickedButton()
        if clicked == overwrite_btn:
            return "overwrite"
        if clicked == rename_btn:
            return "rename"
        return None

    def _unique_import_name(self, name, source_tag):
        """Return the first of "name (tag)", "name (tag 2)", ... not yet used."""
        suffix = 1
        while True:
            candidate = f"{name} ({source_tag})" if suffix == 1 else f"{name} ({source_tag} {suffix})"
            if candidate not in self.profiles:
                return candidate
            suffix += 1

    def _import_profiles(self):
        """Import profiles from another experiment into the current experiment."""
        start_dir = self._default_import_dir()
        selected_path, _ = QFileDialog.getOpenFileName(
            self,
            "Import Profiles From Experiment",
            start_dir,
            "Experiment/Profile Files (*.yaml *.yml *.json);;All Files (*)",
        )
        if not selected_path:
            return

        profiles_path = self._resolve_import_profiles_path(selected_path)
        if not profiles_path or not os.path.exists(profiles_path):
            QMessageBox.warning(
                self,
                "Profiles Not Found",
                "Could not locate a valid 'training_profiles.json' from the selected experiment/file.",
            )
            return

        try:
            external_profiles = self._load_external_profiles(profiles_path)
        except Exception as e:
            QMessageBox.warning(self, "Import Failed", f"Failed to load profiles:\n{e}")
            return

        duplicate_names = sorted(name for name in external_profiles if name in self.profiles)
        if duplicate_names:
            duplicate_mode = self._ask_duplicate_mode(duplicate_names)
            if duplicate_mode is None:
                return
        else:
            duplicate_mode = "overwrite"

        # Tag renamed imports with the source experiment's folder name. The
        # folder name is used whole: splitting off an "extension" would turn
        # e.g. "exp.v2" into "exp".
        source_tag = os.path.basename(os.path.dirname(os.path.abspath(profiles_path))) or "imported"
        imported_names = []
        for name, cfg in external_profiles.items():
            target_name = name
            # Checked against the live dict, so names produced earlier in this
            # loop are also avoided.
            if duplicate_mode == "rename" and target_name in self.profiles:
                target_name = self._unique_import_name(name, source_tag)
            self.profiles[target_name] = copy.deepcopy(cfg)
            imported_names.append(target_name)

        saved = self._save_profiles()
        self._refresh_list()

        if imported_names:
            self._select_profile(imported_names[-1])

        if saved:
            QMessageBox.information(
                self,
                "Profiles Imported",
                f"Imported {len(imported_names)} profile(s) from:\n{profiles_path}",
            )

    # ------------------------------------------------------------------
    # Manage selected
    # ------------------------------------------------------------------

    def _duplicate_selected(self):
        """Duplicate the selected profile under a new name."""
        item = self.list_widget.currentItem()
        if not item:
            return
        name = item.text()
        if name not in self.profiles:
            return

        new_name, ok = QInputDialog.getText(self, "Duplicate Profile", "New Profile Name:", text=f"Copy of {name}")
        new_name = new_name.strip() if ok else ""
        if not new_name:
            return
        if new_name in self.profiles:
            QMessageBox.warning(self, "Error", f"Profile '{new_name}' already exists.")
            return

        # Deep copy so nested lists/dicts are not shared between the original
        # and the duplicate (same approach as _import_profiles).
        self.profiles[new_name] = copy.deepcopy(self.profiles[name])
        self._save_profiles()
        self._refresh_list()
        self._select_profile(new_name)

    def _rename_selected(self):
        """Rename the selected profile, keeping its batch check mark."""
        item = self.list_widget.currentItem()
        if not item:
            return
        old_name = item.text()
        # _refresh_list restores check marks by name, so remember this one.
        was_checked = item.checkState() == Qt.CheckState.Checked

        new_name, ok = QInputDialog.getText(self, "Rename Profile", "New Name:", text=old_name)
        new_name = new_name.strip() if ok else ""
        if not new_name or new_name == old_name:
            return

        if new_name in self.profiles:
            QMessageBox.warning(self, "Error", f"Profile '{new_name}' already exists.")
            return
        if old_name not in self.profiles:
            return

        # Move the config under the new key (the list is displayed sorted, so
        # dict insertion order does not matter).
        self.profiles[new_name] = self.profiles.pop(old_name)
        self._save_profiles()
        self._refresh_list()

        # Select the renamed item and carry over its check state
        new_item = self._select_profile(new_name)
        if new_item is not None and was_checked:
            new_item.setCheckState(Qt.CheckState.Checked)

    def _load_selected(self):
        """Load selected profile settings into main UI."""
        item = self.list_widget.currentItem()
        if not item:
            return
        name = item.text()
        if name in self.profiles:
            try:
                self.parent_widget.apply_training_config(self.profiles[name])
                QMessageBox.information(self, "Loaded", f"Loaded profile: {name}\nSettings applied to UI.")
            except Exception as e:
                QMessageBox.warning(self, "Error", f"Failed to apply profile: {e}")

    def _delete_selected(self):
        """Delete the selected profile after confirmation."""
        item = self.list_widget.currentItem()
        if not item:
            return
        name = item.text()
        if QMessageBox.question(self, "Delete", f"Delete profile '{name}'?") != QMessageBox.StandardButton.Yes:
            return
        # pop() tolerates a profile that already disappeared (e.g. after a reload).
        self.profiles.pop(name, None)
        self._save_profiles()
        self._refresh_list()

    def get_selected_profiles_for_batch(self):
        """Return list of (name, config) for checked items, in list order."""
        selected = []
        for i in range(self.list_widget.count()):
            item = self.list_widget.item(i)
            if item.checkState() == Qt.CheckState.Checked:
                name = item.text()
                if name in self.profiles:
                    selected.append((name, self.profiles[name]))
        return selected
418
+