singlebehaviorlab 2.3.1.tar.gz → 2.3.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. {singlebehaviorlab-2.3.1/singlebehaviorlab.egg-info → singlebehaviorlab-2.3.2}/PKG-INFO +1 -1
  2. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/pyproject.toml +1 -1
  3. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/__init__.py +2 -1
  4. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/clustering.py +25 -3
  5. singlebehaviorlab-2.3.2/singlebehaviorlab/backend/contrastive.py +202 -0
  6. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/registration.py +9 -1
  7. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/cli.py +6 -0
  8. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/clustering_widget.py +115 -21
  9. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/registration_widget.py +49 -50
  10. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2/singlebehaviorlab.egg-info}/PKG-INFO +1 -1
  11. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab.egg-info/SOURCES.txt +1 -0
  12. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/LICENSE +0 -0
  13. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/README.md +0 -0
  14. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/setup.cfg +0 -0
  15. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/__main__.py +0 -0
  16. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/_paths.py +0 -0
  17. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/__init__.py +0 -0
  18. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/augmentations.py +0 -0
  19. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/data_store.py +0 -0
  20. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/inference.py +0 -0
  21. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/model.py +0 -0
  22. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/segmentation.py +0 -0
  23. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/segments.py +0 -0
  24. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/train.py +0 -0
  25. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/training_runner.py +0 -0
  26. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/uncertainty.py +0 -0
  27. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/video_processor.py +0 -0
  28. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/backend/video_utils.py +0 -0
  29. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/config.py +0 -0
  30. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/data/config/config.yaml +0 -0
  31. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/data/training_profiles.json +0 -0
  32. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/demo.py +0 -0
  33. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/__init__.py +0 -0
  34. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/analysis_widget.py +0 -0
  35. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/attention_export.py +0 -0
  36. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/clip_extraction_widget.py +0 -0
  37. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/inference_popups.py +0 -0
  38. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/inference_widget.py +0 -0
  39. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/inference_worker.py +0 -0
  40. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/interactive_timeline.py +0 -0
  41. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/labeling_widget.py +0 -0
  42. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/main_window.py +0 -0
  43. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/metadata_management_widget.py +0 -0
  44. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/motion_tracking.py +0 -0
  45. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/overlay_export.py +0 -0
  46. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/plot_integration.py +0 -0
  47. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/qt_helpers.py +0 -0
  48. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/review_widget.py +0 -0
  49. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/segmentation_tracking_widget.py +0 -0
  50. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/tab_tutorial_dialog.py +0 -0
  51. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/timeline_themes.py +0 -0
  52. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/training_profiles.py +0 -0
  53. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/training_widget.py +0 -0
  54. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/gui/video_utils.py +0 -0
  55. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/licenses/SAM2-LICENSE +0 -0
  56. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab/licenses/VideoPrism-LICENSE +0 -0
  57. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab.egg-info/dependency_links.txt +0 -0
  58. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab.egg-info/entry_points.txt +0 -0
  59. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab.egg-info/requires.txt +0 -0
  60. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/singlebehaviorlab.egg-info/top_level.txt +0 -0
  61. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/tests/test_clustering_smoke.py +0 -0
  62. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/tests/test_config.py +0 -0
  63. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/tests/test_motion_tracking.py +0 -0
  64. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/tests/test_paths.py +0 -0
  65. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/tests/test_sam2_smoke.py +0 -0
  66. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/tests/test_segments.py +0 -0
  67. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/__init__.py +0 -0
  68. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/automatic_mask_generator.py +0 -0
  69. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/benchmark.py +0 -0
  70. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/build_sam.py +0 -0
  71. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_b+.yaml +0 -0
  72. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_l.yaml +0 -0
  73. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_s.yaml +0 -0
  74. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_t.yaml +0 -0
  75. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +0 -0
  76. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +0 -0
  77. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +0 -0
  78. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +0 -0
  79. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +0 -0
  80. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/__init__.py +0 -0
  81. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/backbones/__init__.py +0 -0
  82. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/backbones/hieradet.py +0 -0
  83. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/backbones/image_encoder.py +0 -0
  84. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/backbones/utils.py +0 -0
  85. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/memory_attention.py +0 -0
  86. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/memory_encoder.py +0 -0
  87. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/position_encoding.py +0 -0
  88. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/sam/__init__.py +0 -0
  89. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/sam/mask_decoder.py +0 -0
  90. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/sam/prompt_encoder.py +0 -0
  91. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/sam/transformer.py +0 -0
  92. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/sam2_base.py +0 -0
  93. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/modeling/sam2_utils.py +0 -0
  94. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_hiera_b+.yaml +0 -0
  95. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_hiera_l.yaml +0 -0
  96. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_hiera_s.yaml +0 -0
  97. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_hiera_t.yaml +0 -0
  98. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_image_predictor.py +0 -0
  99. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_video_predictor.py +0 -0
  100. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/sam2_video_predictor_legacy.py +0 -0
  101. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/utils/__init__.py +0 -0
  102. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/utils/amg.py +0 -0
  103. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/utils/misc.py +0 -0
  104. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/sam2_backend/sam2/utils/transforms.py +0 -0
  105. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/videoprism_backend/videoprism/__init__.py +0 -0
  106. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/videoprism_backend/videoprism/encoders.py +0 -0
  107. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/videoprism_backend/videoprism/layers.py +0 -0
  108. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/videoprism_backend/videoprism/models.py +0 -0
  109. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/videoprism_backend/videoprism/tokenizers.py +0 -0
  110. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2}/third_party/videoprism_backend/videoprism/utils.py +0 -0
{singlebehaviorlab-2.3.1/singlebehaviorlab.egg-info → singlebehaviorlab-2.3.2}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: singlebehaviorlab
- Version: 2.3.1
+ Version: 2.3.2
  Summary: Semi-automated behavioral video annotation, training, and analysis tool
  Author: Almir Aljovic
  Maintainer: Almir Aljovic
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
  
  [project]
  name = "singlebehaviorlab"
- version = "2.3.1"
+ version = "2.3.2"
  description = "Semi-automated behavioral video annotation, training, and analysis tool"
  readme = "README.md"
  license = { file = "LICENSE" }
singlebehaviorlab/__init__.py

@@ -19,7 +19,7 @@ or videoprism. Each symbol triggers its underlying backend module only on
  first access.
  """
  
- __version__ = "2.3.1"
+ __version__ = "2.3.2"
  __author__ = "Almir Aljovic"
  
  # Mapping of public name → (backend module, attribute name).
@@ -35,6 +35,7 @@ _PUBLIC_API = {
      "infer": ("singlebehaviorlab.backend.inference", "run_inference_on_video"),
      "train": ("singlebehaviorlab.backend.training_runner", "run_training_session"),
      "load_config": ("singlebehaviorlab.config", "load_config"),
+     "learn_behavior_features": ("singlebehaviorlab.backend.contrastive", "learn_behavior_features"),
      "load_demo": ("singlebehaviorlab.demo", "load_demo"),
      "DEMOS": ("singlebehaviorlab.demo", "DEMOS"),
  }
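
Usage note: the new symbol rides the package's lazy public API, so the torch-backed module is imported only on first attribute access. A minimal calling sketch (the file paths are placeholders, not from this release):

    import singlebehaviorlab as sbl

    # First access triggers the lazy import of
    # singlebehaviorlab.backend.contrastive.
    paths = sbl.learn_behavior_features(
        "experiment/behaviorome_matrix.npz",    # placeholder input
        "experiment/behavior_proj_matrix.npz",  # placeholder output
        metadata_path="experiment/behaviorome_metadata.npz",
        log_fn=print,
    )
    print(paths["matrix"], paths["metadata"])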
singlebehaviorlab/backend/clustering.py

@@ -28,7 +28,8 @@ class ClusteringParams:
      n_components: int = 2
      n_neighbors: int = 15
      min_dist: float = 0.1
-     normalization: str = "standard"  # standard | minmax | l2 | none
+     normalization: str = "standard"
+     subtract_video_mean: bool = False
      leiden_resolution: float = 1.0
      leiden_k: int = 15
      min_cluster_size: int = 10
@@ -157,8 +158,29 @@ def run_clustering(
      matrix_df, metadata_df = _load_matrix_metadata(matrix_path_str, metadata_path_str)
      _log(f"Matrix shape: {matrix_df.shape[0]} features × {matrix_df.shape[1]} samples")
  
-     processed = _normalize(matrix_df.T, params.normalization)
-     _log(f"Processed shape: {processed.shape} (samples × features); normalization={params.normalization}")
+     X = matrix_df.T
+     X = X.replace([np.inf, -np.inf], np.nan).fillna(0.0)
+
+     if params.subtract_video_mean and metadata_df is not None:
+         group_col = None
+         for col in ("group", "video_id"):
+             if col in metadata_df.columns:
+                 group_col = col
+                 break
+         snippet_col = "snippet" if "snippet" in metadata_df.columns else None
+         if group_col and snippet_col:
+             for grp in metadata_df[group_col].unique():
+                 grp_snippets = metadata_df.loc[metadata_df[group_col] == grp, snippet_col].values
+                 mask = X.index.isin(grp_snippets)
+                 if mask.sum() > 1:
+                     X.loc[mask] -= X.loc[mask].mean(axis=0)
+             _log("Applied per-video mean subtraction")
+
+     processed = _normalize(X, params.normalization)
+
+     _log(f"Processed shape: {processed.shape} (samples × features)")
  
      _log(
          f"Running UMAP (n_neighbors={params.n_neighbors}, "
singlebehaviorlab-2.3.2/singlebehaviorlab/backend/contrastive.py (new file)

@@ -0,0 +1,202 @@
+ """Temporal contrastive projection for behavior-focused embeddings.
+
+ Trains a lightweight MLP on pre-computed VideoPrism embeddings using
+ temporal proximity as the supervision signal: clips close in time within
+ the same video should map nearby; clips far apart should map far away.
+ The projected embeddings suppress static visual factors (lighting,
+ background, camera) and amplify behavioral dynamics.
+ """
+
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+ from typing import Any, Callable, Optional
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ __all__ = ["learn_behavior_features"]
+
+ _DEFAULT_DIM = 128
+ _DEFAULT_EPOCHS = 30
+ _DEFAULT_LR = 3e-4
+ _POSITIVE_WINDOW = 5
+ _TEMPERATURE = 0.07
+
+
+ class _ProjectionHead(nn.Module):
+     def __init__(self, in_dim: int, out_dim: int):
+         super().__init__()
+         hidden = max(out_dim, in_dim // 2)
+         self.net = nn.Sequential(
+             nn.Linear(in_dim, hidden),
+             nn.ReLU(),
+             nn.Linear(hidden, out_dim),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.normalize(self.net(x), dim=-1)
+
+
+ def _build_pairs(
+     metadata: pd.DataFrame,
+     n_samples: int,
+     positive_window: int,
+     rng: np.random.Generator,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+     group_col = None
+     for col in ("group", "video_id"):
+         if col in metadata.columns:
+             group_col = col
+             break
+     snippet_col = "snippet" if "snippet" in metadata.columns else None
+     if not group_col or not snippet_col:
+         # Fallback: treat the whole dataset as one sequence. Sizes follow
+         # len(anchors) so shapes stay consistent when n_samples > len(metadata).
+         indices = np.arange(len(metadata))
+         rng.shuffle(indices)
+         anchors = indices[:n_samples]
+         positives = np.clip(
+             anchors + rng.integers(-positive_window, positive_window + 1, size=len(anchors)),
+             0, len(metadata) - 1,
+         )
+         negatives = rng.integers(0, len(metadata), size=len(anchors))
+         return anchors, positives, negatives
+
+     groups = metadata[group_col].values
+     unique_groups = np.unique(groups)
+     group_indices: dict[Any, np.ndarray] = {}
+     for g in unique_groups:
+         group_indices[g] = np.where(groups == g)[0]
+
+     anchors = []
+     positives = []
+     negatives = []
+     per_group = max(1, n_samples // len(unique_groups))
+
+     for g in unique_groups:
+         idx = group_indices[g]
+         if len(idx) < 2:
+             continue
+         a = rng.choice(idx, size=min(per_group, len(idx)), replace=len(idx) < per_group)
+         for ai in a:
+             pos_in_group = np.where(idx == ai)[0][0]
+             lo = max(0, pos_in_group - positive_window)
+             hi = min(len(idx), pos_in_group + positive_window + 1)
+             candidates = idx[lo:hi]
+             candidates = candidates[candidates != ai]
+             if len(candidates) == 0:
+                 continue
+             pi = rng.choice(candidates)
+
+             other_groups = [og for og in unique_groups if og != g]
+             if other_groups:
+                 ng = rng.choice(other_groups)
+                 ni = rng.choice(group_indices[ng])
+             else:
+                 far_lo = max(0, pos_in_group - 3 * positive_window)
+                 far_hi = min(len(idx), pos_in_group + 3 * positive_window + 1)
+                 far_candidates = np.setdiff1d(idx, idx[far_lo:far_hi])
+                 if len(far_candidates) == 0:
+                     far_candidates = idx
+                 ni = rng.choice(far_candidates)
+
+             anchors.append(ai)
+             positives.append(pi)
+             negatives.append(ni)
+
+     return np.array(anchors), np.array(positives), np.array(negatives)
+
+
+ def _info_nce_loss(anchor, positive, negative, temperature):
+     pos_sim = (anchor * positive).sum(dim=-1) / temperature
+     neg_sim = (anchor * negative).sum(dim=-1) / temperature
+     logits = torch.stack([pos_sim, neg_sim], dim=-1)
+     labels = torch.zeros(len(anchor), dtype=torch.long, device=anchor.device)
+     return F.cross_entropy(logits, labels)
+
+
+ def learn_behavior_features(
+     matrix_path: str | os.PathLike[str],
+     output_path: str | os.PathLike[str],
+     *,
+     metadata_path: Optional[str | os.PathLike[str]] = None,
+     out_dim: int = _DEFAULT_DIM,
+     epochs: int = _DEFAULT_EPOCHS,
+     lr: float = _DEFAULT_LR,
+     positive_window: int = _POSITIVE_WINDOW,
+     temperature: float = _TEMPERATURE,
+     log_fn: Optional[Callable[[str], None]] = None,
+ ) -> dict[str, str]:
+     """Train a contrastive projection and write the projected embedding matrix.
+
+     Returns dict with ``matrix`` and ``metadata`` output paths.
+     """
+     from singlebehaviorlab.backend.clustering import _load_matrix_metadata
+
+     matrix_path = str(Path(matrix_path).expanduser().resolve())
+     output_path_obj = Path(output_path).expanduser().resolve()
+     output_path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+     metadata_path_str = str(Path(metadata_path).expanduser().resolve()) if metadata_path else None
+
+     def _log(msg: str) -> None:
+         if log_fn:
+             log_fn(msg)
+
+     matrix_df, metadata_df = _load_matrix_metadata(matrix_path, metadata_path_str)
+     X = matrix_df.T
+     embeddings = X.values.astype(np.float32)
+     n_samples, in_dim = embeddings.shape
+     _log(f"Loaded {n_samples} embeddings ({in_dim}-dim)")
+
+     if metadata_df is None:
+         metadata_df = pd.DataFrame({"snippet": X.index, "group": "video_0"})
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = _ProjectionHead(in_dim, out_dim).to(device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+     all_emb = torch.from_numpy(embeddings).to(device)
+
+     rng = np.random.default_rng(42)
+     pairs_per_epoch = max(1024, min(n_samples * 4, 65536))
+
+     _log(f"Training projection head ({in_dim} → {out_dim}) for {epochs} epochs on {device}")
+     for epoch in range(epochs):
+         anchors, positives, negatives = _build_pairs(metadata_df, pairs_per_epoch, positive_window, rng)
+         if len(anchors) == 0:
+             _log("No valid pairs found — check metadata has group/video_id column")
+             break
+         a_emb = model(all_emb[anchors])
+         p_emb = model(all_emb[positives])
+         n_emb = model(all_emb[negatives])
+         loss = _info_nce_loss(a_emb, p_emb, n_emb, temperature)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         if (epoch + 1) % 10 == 0 or epoch == 0:
+             _log(f"  epoch {epoch + 1}/{epochs}  loss={loss.item():.4f}")
+
+     model.eval()
+     with torch.no_grad():
+         projected = model(all_emb).cpu().numpy()
+     _log(f"Projected embeddings: {projected.shape}")
+
+     snippet_ids = np.array(X.index.tolist())
+     feature_names = np.array([f"behavior_feat_{i}" for i in range(out_dim)])
+
+     out_matrix = str(output_path_obj)
+     if out_matrix.endswith("_matrix.npz"):
+         out_metadata = out_matrix.replace("_matrix.npz", "_metadata.npz")
+     elif out_matrix.endswith(".npz"):
+         out_metadata = out_matrix[:-4] + "_metadata.npz"
+     else:
+         out_metadata = out_matrix + "_metadata.npz"
+
+     np.savez_compressed(out_matrix, matrix=projected.T, feature_names=feature_names, snippet_ids=snippet_ids)
+     _log(f"Wrote projected matrix: {out_matrix}")
+
+     if metadata_df is not None:
+         np.savez_compressed(out_metadata, metadata=metadata_df.values, columns=np.array(metadata_df.columns))
+         _log(f"Wrote metadata: {out_metadata}")
+
+     return {"matrix": out_matrix, "metadata": out_metadata}
singlebehaviorlab/backend/registration.py

@@ -40,6 +40,7 @@ class RegistrationParams:
      clip_length_frames: int = 16
      step_frames: Optional[int] = None
      backbone_model: str = "videoprism_public_v1_base"
+     flip_invariant: bool = False
      experiment_name: Optional[str] = None
  
      @property
@@ -71,6 +72,7 @@ def _extract_embedding(
      backbone: VideoPrismBackbone,
      frames: np.ndarray,
      target_size: int,
+     flip_invariant: bool = False,
  ) -> Optional[np.ndarray]:
      try:
          resized = np.array([cv2.resize(f, (target_size, target_size)) for f in frames])
@@ -79,6 +81,12 @@
          with torch.no_grad():
              tokens = backbone(tensor)
              embedding = tokens.mean(dim=1).squeeze(0).cpu().numpy()
+             if flip_invariant:
+                 embs = [embedding]
+                 for dims in [[-1], [-2], [-1, -2]]:
+                     t_flip = torch.flip(tensor, dims=dims)
+                     embs.append(backbone(t_flip).mean(dim=1).squeeze(0).cpu().numpy())
+                 embedding = np.mean(embs, axis=0)
          return embedding.astype(np.float32)
      except Exception:
          return None
@@ -178,7 +186,7 @@ def run_registration(
          if frames is None or len(frames) == 0:
              _log(f"Skipping {os.path.basename(clip_path)}: no frames")
              continue
-         embedding = _extract_embedding(backbone, frames, params.target_size)
+         embedding = _extract_embedding(backbone, frames, params.target_size, params.flip_invariant)
          del frames
          if embedding is None:
              _log(f"Skipping {os.path.basename(clip_path)}: embedding failed")
singlebehaviorlab/cli.py

@@ -178,6 +178,10 @@ def _build_parser() -> argparse.ArgumentParser:
          "--no-clahe", dest="clahe", action="store_false", default=None,
          help="Disable CLAHE contrast normalization.",
      )
+     register_parser.add_argument(
+         "--flip-invariant", action="store_true",
+         help="Average embeddings over 4 orientations (original, h-flip, v-flip, both) to remove facing-direction bias. 4x extraction time.",
+     )
      _add_common_runtime_flags(register_parser)
  
      segment_parser = subparsers.add_parser(
@@ -347,6 +351,8 @@ def cmd_register(args: argparse.Namespace) -> int:
          params.target_fps = int(args.target_fps)
      if args.clahe is False:
          params.normalization_method = "None"
+     if args.flip_invariant:
+         params.flip_invariant = True
  
      bar = {"pbar": None}
  
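Invocation sketch, assuming the module entry point (the package ships a __main__.py) forwards to this parser; the remaining register arguments don't appear in this diff and are left elided:

    python -m singlebehaviorlab register ... --flip-invariant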
singlebehaviorlab/gui/clustering_widget.py

@@ -27,7 +27,7 @@ from PyQt6.QtWidgets import (
      QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
      QComboBox, QSlider, QCheckBox, QGroupBox, QScrollArea, QSplitter,
      QMessageBox, QListWidget, QTextEdit, QFileDialog, QProgressBar, QDialog,
-     QSizePolicy, QDialogButtonBox
+     QSizePolicy, QDialogButtonBox, QApplication
  )
  from PyQt6.QtCore import Qt, QThread, pyqtSignal
  from PyQt6.QtGui import QFont
@@ -234,11 +234,29 @@ class ClusteringWidget(QWidget):
          )
          norm_row.addWidget(self.normalization_method)
          preprocess_layout.addLayout(norm_row)
  
+         self.subtract_video_mean_check = QCheckBox("Subtract per-video mean")
+         self.subtract_video_mean_check.setToolTip(
+             "Remove the average embedding of each video/group before clustering.\n"
+             "Reduces sensitivity to camera setup, lighting, and background\n"
+             "while preserving within-video behavior differences."
+         )
+         preprocess_layout.addWidget(self.subtract_video_mean_check)
+
+         self.learn_features_btn = QPushButton("Learn behavior features")
+         self.learn_features_btn.setToolTip(
+             "Train a contrastive projection on the loaded embeddings.\n"
+             "Clips close in time map nearby; clips far apart map far away.\n"
+             "Suppresses static visual factors and amplifies behavioral dynamics.\n"
+             "Replaces the current matrix with 128-dim projected embeddings."
+         )
+         self.learn_features_btn.clicked.connect(self._learn_behavior_features)
+         preprocess_layout.addWidget(self.learn_features_btn)
+
          self.preprocess_btn = QPushButton("Apply preprocessing")
          self.preprocess_btn.clicked.connect(self.apply_preprocessing)
          preprocess_layout.addWidget(self.preprocess_btn)
  
          self.preprocess_status = QLabel("Ready")
          self.preprocess_status.setWordWrap(True)
          self.preprocess_status.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum)
@@ -742,8 +760,68 @@
          except Exception as e:
              QMessageBox.critical(self, "Load Error", f"Failed to load data: {e}")
  
+     def _learn_behavior_features(self):
+         if self.matrix_data is None:
+             QMessageBox.warning(self, "No data", "Load a feature matrix first.")
+             return
+         try:
+             import tempfile
+             from singlebehaviorlab.backend.contrastive import learn_behavior_features
+
+             self.preprocess_status.setText("Training contrastive projection...")
+             self.preprocess_status.setStyleSheet("color: blue;")
+             QApplication.processEvents()
+
+             with tempfile.NamedTemporaryFile(suffix="_matrix.npz", delete=False) as tmp_m:
+                 tmp_matrix = tmp_m.name
+             tmp_metadata = tmp_matrix.replace("_matrix.npz", "_metadata.npz")
+             out_matrix = tmp_matrix.replace("_matrix.npz", "_proj_matrix.npz")
+
+             snippet_ids = np.array(self.matrix_data.columns.tolist())
+             feature_names = np.array(self.matrix_data.index.tolist())
+             np.savez_compressed(tmp_matrix, matrix=self.matrix_data.values, feature_names=feature_names, snippet_ids=snippet_ids)
+             if self.metadata is not None:
+                 np.savez_compressed(tmp_metadata, metadata=self.metadata.values, columns=np.array(self.metadata.columns))
+             else:
+                 tmp_metadata = None
+
+             log_lines = []
+             result = learn_behavior_features(
+                 tmp_matrix,
+                 out_matrix,
+                 metadata_path=tmp_metadata,
+                 log_fn=lambda msg: (log_lines.append(msg), self.preprocess_status.setText(msg), QApplication.processEvents()),
+             )
+
+             proj = np.load(result["matrix"], allow_pickle=True)
+             self.matrix_data = pd.DataFrame(proj["matrix"], index=proj["feature_names"], columns=proj["snippet_ids"])
+
+             for f in [tmp_matrix, tmp_metadata, out_matrix, result.get("metadata")]:
+                 if f and os.path.exists(f):
+                     try:
+                         os.unlink(f)
+                     except Exception:
+                         pass
+
+             X = self.matrix_data.T
+             X = X.replace([np.inf, -np.inf], np.nan).fillna(0.0)
+             from sklearn.preprocessing import StandardScaler
+             X_norm = StandardScaler().fit_transform(X)
+             self.processed_data = pd.DataFrame(X_norm, index=X.index, columns=range(X_norm.shape[1]))
+
+             n = self.matrix_data.shape[1]
+             self.preprocess_status.setText(
+                 f"Behavior features: {n} clips → 128-dim (contrastive) → standardized. Ready to cluster."
+             )
+             self.preprocess_status.setStyleSheet("color: green;")
+             self.run_btn.setEnabled(True)
+
+         except Exception as e:
+             self.preprocess_status.setText(f"Feature learning failed: {e}")
+             self.preprocess_status.setStyleSheet("color: red;")
+             QMessageBox.critical(self, "Error", f"Contrastive training failed:\n{e}")
+
      def apply_preprocessing(self):
-         """Apply normalization."""
          if self.matrix_data is None:
              return
  
@@ -759,28 +837,44 @@
          # sklearn expects Samples as Rows. So we transpose.
  
          X = data.T
-
-         # Clean infinite/NaN
-         X = X.replace([np.inf, -np.inf], np.nan)
-
-         # Normalize
+         X = X.replace([np.inf, -np.inf], np.nan).fillna(0.0)
+
+         steps = []
+
+         if self.subtract_video_mean_check.isChecked() and self.metadata is not None:
+             group_col = None
+             for col in ("group", "video_id"):
+                 if col in self.metadata.columns:
+                     group_col = col
+                     break
+             if group_col is not None:
+                 snippet_col = "snippet" if "snippet" in self.metadata.columns else None
+                 if snippet_col:
+                     for grp in self.metadata[group_col].unique():
+                         grp_snippets = self.metadata.loc[
+                             self.metadata[group_col] == grp, snippet_col
+                         ].values
+                         mask = X.index.isin(grp_snippets)
+                         if mask.sum() > 1:
+                             X.loc[mask] -= X.loc[mask].mean(axis=0)
+                     steps.append("video-mean-sub")
+
          norm_method = self.normalization_method.currentText()
          if norm_method == 'standard':
-             scaler = StandardScaler()
-             X_norm = scaler.fit_transform(X)
+             X_norm = StandardScaler().fit_transform(X)
          elif norm_method == 'minmax':
-             scaler = MinMaxScaler()
-             X_norm = scaler.fit_transform(X)
+             X_norm = MinMaxScaler().fit_transform(X)
          elif norm_method == 'l2':
-             scaler = Normalizer(norm='l2')
-             X_norm = scaler.fit_transform(X)
+             X_norm = Normalizer(norm='l2').fit_transform(X)
          else:
-             X_norm = X
-
-         # Store processed data (Samples x Features)
-         self.processed_data = pd.DataFrame(X_norm, index=X.index, columns=X.columns)
-
-         self.preprocess_status.setText(f"Normalized: {norm_method}")
+             X_norm = X.values if hasattr(X, 'values') else X
+         if norm_method != 'none':
+             steps.append(norm_method)
+
+         self.processed_data = pd.DataFrame(X_norm, index=X.index, columns=range(X_norm.shape[1]))
+
+         self.preprocess_status.setText(f"Preprocessed: {' → '.join(steps) or 'none'}")
          self.preprocess_status.setStyleSheet("color: green;")
  
      except Exception as e:
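
Stripped of the Qt plumbing, the widget's preprocessing path above reduces to the following standalone sketch (clean, optionally center per video, then standardize; the column names mirror the code above and are otherwise assumptions):

    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    def preprocess(data, metadata=None, subtract_video_mean=False):
        # data is features × samples; transpose so sklearn sees samples × features.
        X = data.T.replace([np.inf, -np.inf], np.nan).fillna(0.0)
        if (subtract_video_mean and metadata is not None
                and {"group", "snippet"} <= set(metadata.columns)):
            for grp in metadata["group"].unique():
                mask = X.index.isin(metadata.loc[metadata["group"] == grp, "snippet"])
                if mask.sum() > 1:
                    X.loc[mask] -= X.loc[mask].mean(axis=0)
        return pd.DataFrame(StandardScaler().fit_transform(X), index=X.index)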
singlebehaviorlab/gui/registration_widget.py

@@ -326,15 +326,18 @@ class EmbeddingExtractionWorker(QThread):
      error = pyqtSignal(str)
      log_message = pyqtSignal(str)
  
-     def __init__(self, clip_paths: list, output_dir: str, experiment_name: str = None, model_name: str = 'videoprism_public_v1_base', clip_frame_ranges: dict = None, append_to_existing: bool = False):
+     def __init__(self, clip_paths: list, output_dir: str, experiment_name: str = None, model_name: str = 'videoprism_public_v1_base', clip_frame_ranges: dict = None, append_to_existing: bool = False, flip_invariant: bool = False, align_orientation: bool = False, mask_path: str = None):
          super().__init__()
-         self.clip_paths = clip_paths  # List of clip paths (strings)
-         self.clip_frame_ranges = clip_frame_ranges or {}  # Dict mapping clip_path -> (start_frame, end_frame)
+         self.clip_paths = clip_paths
+         self.clip_frame_ranges = clip_frame_ranges or {}
          self.output_dir = output_dir
          self.experiment_name = experiment_name
          self.model_name = model_name
          self.should_stop = False
          self.append_to_existing = append_to_existing
+         self.flip_invariant = flip_invariant
+         self.align_orientation = align_orientation
+         self.mask_path = mask_path
  
      def stop(self):
          self.should_stop = True
@@ -352,6 +355,8 @@
  
          embed_dim = backbone.get_embed_dim()
          self.log_message.emit(f"VideoPrism model loaded. Embedding dimension: {embed_dim}")
+         if self.flip_invariant:
+             self.log_message.emit("Flip-invariant mode: averaging 4 orientations (original, hflip, vflip, both)")
  
          feature_matrix = []
          metadata = []
@@ -372,10 +377,7 @@
                  self.log_message.emit(f"Warning: Could not load frames from {clip_name}, skipping")
                  continue
  
-             # Extract embedding
              embedding = self._extract_embedding(backbone, frames)
-
-             # Free frames memory immediately after use
              del frames
  
              if embedding is None:
@@ -539,25 +541,7 @@
                  self.log_message.emit(f"NPZ save failed (metadata): {e}")
                  npz_metadata_path = None
  
-             # Also save Parquet as backup (faster than CSV, still readable)
-             try:
-                 matrix_df = pd.DataFrame(feature_matrix.T, index=feature_names, columns=snippet_ids)
-                 parquet_matrix_path = os.path.join(self.output_dir, f'{base_name}_matrix.parquet')
-                 matrix_df.to_parquet(parquet_matrix_path, index=True)
-                 self.log_message.emit(f"Saved feature matrix (Parquet) to {parquet_matrix_path}")
-             except Exception as e:
-                 self.log_message.emit(f"Parquet save failed (matrix): {e}")
-
-             try:
-                 parquet_metadata_path = os.path.join(self.output_dir, f'{base_name}_metadata.parquet')
-                 metadata_df.to_parquet(parquet_metadata_path, index=False)
-                 self.log_message.emit(f"Saved metadata (Parquet) to {parquet_metadata_path}")
-             except Exception as e:
-                 self.log_message.emit(f"Parquet save failed (metadata): {e}")
-
-             # Emit NPZ paths (primary format)
-             self.finished.emit(npz_matrix_path if npz_matrix_path else parquet_matrix_path,
-                                npz_metadata_path if npz_metadata_path else parquet_metadata_path)
+             self.finished.emit(npz_matrix_path, npz_metadata_path)
  
          except Exception as e:
              import traceback
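
With the Parquet fallback removed, NPZ is the single persisted format. Reading a matrix back follows the key layout used by the np.savez_compressed calls elsewhere in this release (the path is a placeholder):

    import numpy as np
    import pandas as pd

    npz = np.load("experiment/behaviorome_matrix.npz", allow_pickle=True)
    matrix_df = pd.DataFrame(
        npz["matrix"],                # features × samples
        index=npz["feature_names"],
        columns=npz["snippet_ids"],
    )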
@@ -584,38 +568,36 @@
          return np.array(frames) if frames else None
  
      def _extract_embedding(self, backbone: VideoPrismBackbone, frames: np.ndarray) -> np.ndarray:
-         """Extract mean-pooled VideoPrism embedding from frames."""
          try:
-             # Resize frames to 288x288 (VideoPrism expects this)
              target_size = 288
              processed_frames = []
              for frame in frames:
                  resized = cv2.resize(frame, (target_size, target_size))
                  processed_frames.append(resized)
              frames_resized = np.array(processed_frames)
-             del processed_frames  # Free list memory
-
-             # Convert to PyTorch format: (T, C, H, W) and normalize to [0, 1]
-             frames_t = np.transpose(frames_resized, (0, 3, 1, 2))  # (T, C, H, W)
-             del frames_resized  # Free numpy array
-
+             del processed_frames
+             frames_t = np.transpose(frames_resized, (0, 3, 1, 2))
+             del frames_resized
              frames_tensor = torch.from_numpy(frames_t).float() / 255.0
-             del frames_t  # Free numpy array
-
-             # Add batch dimension: (1, T, C, H, W)
+             del frames_t
              frames_tensor = frames_tensor.unsqueeze(0)
  
              with torch.no_grad():
-                 # VideoPrism returns (B, T*N, D) where N = 16*16 = 256
-                 tokens = backbone(frames_tensor)  # (1, T*256, D)
-                 del frames_tensor  # Free input tensor immediately
-
-                 # Mean pool over all tokens to get single embedding vector
-                 embedding = tokens.mean(dim=1).squeeze(0)  # (D,)
-                 del tokens  # Free large token tensor
-
+                 tokens = backbone(frames_tensor)
+                 embedding = tokens.mean(dim=1).squeeze(0)
+                 del tokens
+                 if self.flip_invariant:
+                     embs = [embedding.cpu().numpy()]
+                     for dims in [[-1], [-2], [-1, -2]]:
+                         t_flip = torch.flip(frames_tensor, dims=dims)
+                         embs.append(backbone(t_flip).mean(dim=1).squeeze(0).cpu().numpy())
+                         del t_flip
+                     embedding = torch.from_numpy(np.mean(embs, axis=0))
+                     del embs
+                 del frames_tensor
  
              result = embedding.cpu().numpy()
-             del embedding  # Free GPU tensor
+             del embedding
              return result
  
          except Exception as e:
@@ -821,6 +803,16 @@
          self.output_dir_label = QLabel("Clips will be saved to experiment folder")
          output_layout.addWidget(self.output_dir_label)
  
+         self.flip_invariant_check = QCheckBox("Flip-invariant embeddings")
+         self.flip_invariant_check.setChecked(False)
+         self.flip_invariant_check.setToolTip(
+             "Run each clip through VideoPrism in 4 orientations (original, hflip,\n"
+             "vflip, both) and average the embeddings. Removes sensitivity to the\n"
+             "animal's facing direction and vertical orientation. 4x extraction time."
+         )
+         output_layout.addWidget(self.flip_invariant_check)
+
          self.append_embeddings_check = QCheckBox("Append to existing embeddings if present")
          self.append_embeddings_check.setChecked(False)
          self.append_embeddings_check.setToolTip("When enabled, if an existing behaviorome matrix/metadata is found in the experiment, new embeddings will be appended instead of creating a new file.")
@@ -1090,7 +1082,6 @@
          self.log_text.append(f"Output directory: {self.output_dir}")
          self.log_text.append(f"Created {len(output_paths)} clip(s)")
  
-         # Extract clip paths and frame ranges from tuples
          clip_paths_list = []
          self.clip_frame_ranges = {}
          for item in output_paths:
@@ -1099,8 +1090,8 @@
                  clip_paths_list.append(clip_path)
                  self.clip_frame_ranges[clip_path] = (start_frame, end_frame)
              else:
-                 # Legacy: just a path string
                  clip_paths_list.append(item)
+
  
          # Group clips by video (using extracted paths)
          clips_by_video = {}
@@ -1265,13 +1256,21 @@
          experiment_name = self.config.get("experiment_name", None)
  
          # Start extraction worker with frame ranges if available
+         mask_path = None
+         if self.align_orientation_check.isChecked() and self.video_mask_pairs:
+             mask_path = self.video_mask_pairs[0][1] if len(self.video_mask_pairs) > 0 else None
+             self.log_text.append(f"Align orientation: mask_path={mask_path}, pairs={len(self.video_mask_pairs)}, frame_ranges={len(self.clip_frame_ranges)}")
+
          self.embedding_worker = EmbeddingExtractionWorker(
              clip_paths,
              self.output_dir,
              experiment_name=experiment_name,
              model_name=model_name,
              clip_frame_ranges=self.clip_frame_ranges if hasattr(self, 'clip_frame_ranges') else None,
-             append_to_existing=self.append_embeddings_check.isChecked()
+             append_to_existing=self.append_embeddings_check.isChecked(),
+             flip_invariant=self.flip_invariant_check.isChecked(),
+             align_orientation=self.align_orientation_check.isChecked(),
+             mask_path=mask_path,
          )
          self.embedding_worker.progress.connect(self._on_embedding_progress)
          self.embedding_worker.finished.connect(self._on_embedding_finished)
{singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.2/singlebehaviorlab.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: singlebehaviorlab
- Version: 2.3.1
+ Version: 2.3.2
  Summary: Semi-automated behavioral video annotation, training, and analysis tool
  Author: Almir Aljovic
  Maintainer: Almir Aljovic
singlebehaviorlab.egg-info/SOURCES.txt

@@ -16,6 +16,7 @@ singlebehaviorlab.egg-info/top_level.txt
  singlebehaviorlab/backend/__init__.py
  singlebehaviorlab/backend/augmentations.py
  singlebehaviorlab/backend/clustering.py
+ singlebehaviorlab/backend/contrastive.py
  singlebehaviorlab/backend/data_store.py
  singlebehaviorlab/backend/inference.py
  singlebehaviorlab/backend/model.py