singlebehaviorlab 2.3.1.tar.gz → 2.3.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {singlebehaviorlab-2.3.1/singlebehaviorlab.egg-info → singlebehaviorlab-2.3.3}/PKG-INFO +2 -3
  2. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/pyproject.toml +2 -3
  3. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/__init__.py +2 -1
  4. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/clustering.py +25 -3
  5. singlebehaviorlab-2.3.3/singlebehaviorlab/backend/contrastive.py +202 -0
  6. singlebehaviorlab-2.3.3/singlebehaviorlab/backend/embedding_refine.py +158 -0
  7. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/registration.py +9 -1
  8. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/cli.py +6 -0
  9. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/clustering_widget.py +115 -21
  10. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/inference_widget.py +40 -2
  11. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/inference_worker.py +1 -0
  12. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/registration_widget.py +40 -50
  13. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3/singlebehaviorlab.egg-info}/PKG-INFO +2 -3
  14. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab.egg-info/SOURCES.txt +2 -0
  15. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab.egg-info/requires.txt +0 -1
  16. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/videoprism_backend/videoprism/tokenizers.py +4 -6
  17. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/videoprism_backend/videoprism/utils.py +1 -10
  18. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/LICENSE +0 -0
  19. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/README.md +0 -0
  20. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/setup.cfg +0 -0
  21. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/__main__.py +0 -0
  22. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/_paths.py +0 -0
  23. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/__init__.py +0 -0
  24. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/augmentations.py +0 -0
  25. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/data_store.py +0 -0
  26. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/inference.py +0 -0
  27. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/model.py +0 -0
  28. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/segmentation.py +0 -0
  29. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/segments.py +0 -0
  30. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/train.py +0 -0
  31. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/training_runner.py +0 -0
  32. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/uncertainty.py +0 -0
  33. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/video_processor.py +0 -0
  34. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/video_utils.py +0 -0
  35. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/config.py +0 -0
  36. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/data/config/config.yaml +0 -0
  37. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/data/training_profiles.json +0 -0
  38. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/demo.py +0 -0
  39. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/__init__.py +0 -0
  40. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/analysis_widget.py +0 -0
  41. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/attention_export.py +0 -0
  42. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/clip_extraction_widget.py +0 -0
  43. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/inference_popups.py +0 -0
  44. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/interactive_timeline.py +0 -0
  45. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/labeling_widget.py +0 -0
  46. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/main_window.py +0 -0
  47. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/metadata_management_widget.py +0 -0
  48. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/motion_tracking.py +0 -0
  49. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/overlay_export.py +0 -0
  50. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/plot_integration.py +0 -0
  51. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/qt_helpers.py +0 -0
  52. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/review_widget.py +0 -0
  53. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/segmentation_tracking_widget.py +0 -0
  54. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/tab_tutorial_dialog.py +0 -0
  55. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/timeline_themes.py +0 -0
  56. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/training_profiles.py +0 -0
  57. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/training_widget.py +0 -0
  58. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/gui/video_utils.py +0 -0
  59. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/licenses/SAM2-LICENSE +0 -0
  60. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/licenses/VideoPrism-LICENSE +0 -0
  61. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab.egg-info/dependency_links.txt +0 -0
  62. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab.egg-info/entry_points.txt +0 -0
  63. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab.egg-info/top_level.txt +0 -0
  64. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/tests/test_clustering_smoke.py +0 -0
  65. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/tests/test_config.py +0 -0
  66. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/tests/test_motion_tracking.py +0 -0
  67. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/tests/test_paths.py +0 -0
  68. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/tests/test_sam2_smoke.py +0 -0
  69. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/tests/test_segments.py +0 -0
  70. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/__init__.py +0 -0
  71. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/automatic_mask_generator.py +0 -0
  72. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/benchmark.py +0 -0
  73. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/build_sam.py +0 -0
  74. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_b+.yaml +0 -0
  75. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_l.yaml +0 -0
  76. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_s.yaml +0 -0
  77. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2/sam2_hiera_t.yaml +0 -0
  78. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +0 -0
  79. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +0 -0
  80. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +0 -0
  81. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +0 -0
  82. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +0 -0
  83. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/__init__.py +0 -0
  84. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/backbones/__init__.py +0 -0
  85. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/backbones/hieradet.py +0 -0
  86. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/backbones/image_encoder.py +0 -0
  87. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/backbones/utils.py +0 -0
  88. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/memory_attention.py +0 -0
  89. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/memory_encoder.py +0 -0
  90. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/position_encoding.py +0 -0
  91. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/sam/__init__.py +0 -0
  92. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/sam/mask_decoder.py +0 -0
  93. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/sam/prompt_encoder.py +0 -0
  94. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/sam/transformer.py +0 -0
  95. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/sam2_base.py +0 -0
  96. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/modeling/sam2_utils.py +0 -0
  97. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_hiera_b+.yaml +0 -0
  98. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_hiera_l.yaml +0 -0
  99. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_hiera_s.yaml +0 -0
  100. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_hiera_t.yaml +0 -0
  101. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_image_predictor.py +0 -0
  102. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_video_predictor.py +0 -0
  103. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/sam2_video_predictor_legacy.py +0 -0
  104. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/utils/__init__.py +0 -0
  105. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/utils/amg.py +0 -0
  106. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/utils/misc.py +0 -0
  107. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/sam2_backend/sam2/utils/transforms.py +0 -0
  108. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/videoprism_backend/videoprism/__init__.py +0 -0
  109. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/videoprism_backend/videoprism/encoders.py +0 -0
  110. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/videoprism_backend/videoprism/layers.py +0 -0
  111. {singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/third_party/videoprism_backend/videoprism/models.py +0 -0
{singlebehaviorlab-2.3.1/singlebehaviorlab.egg-info → singlebehaviorlab-2.3.3}/PKG-INFO
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: singlebehaviorlab
-Version: 2.3.1
-Summary: Semi-automated behavioral video annotation, training, and analysis tool
+Version: 2.3.3
+Summary: Behavioral sequencing and phenotyping with lightweight task specific adaptation
 Author: Almir Aljovic
 Maintainer: Almir Aljovic
 License: MIT License
@@ -59,7 +59,6 @@ Requires-Dist: einshape
 Requires-Dist: huggingface-hub
 Requires-Dist: sentencepiece
 Requires-Dist: absl-py
-Requires-Dist: tensorflow-cpu
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
{singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/pyproject.toml
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "singlebehaviorlab"
-version = "2.3.1"
-description = "Semi-automated behavioral video annotation, training, and analysis tool"
+version = "2.3.3"
+description = "Behavioral sequencing and phenotyping with lightweight task specific adaptation"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.10"
@@ -43,7 +43,6 @@ dependencies = [
     "huggingface-hub",
     "sentencepiece",
     "absl-py",
-    "tensorflow-cpu",
 ]
 
 [project.urls]
{singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/__init__.py
@@ -19,7 +19,7 @@ or videoprism. Each symbol triggers its underlying backend module only on
 first access.
 """
 
-__version__ = "2.3.1"
+__version__ = "2.3.2"
 __author__ = "Almir Aljovic"
 
 # Mapping of public name → (backend module, attribute name).
@@ -35,6 +35,7 @@ _PUBLIC_API = {
     "infer": ("singlebehaviorlab.backend.inference", "run_inference_on_video"),
     "train": ("singlebehaviorlab.backend.training_runner", "run_training_session"),
     "load_config": ("singlebehaviorlab.config", "load_config"),
+    "learn_behavior_features": ("singlebehaviorlab.backend.contrastive", "learn_behavior_features"),
     "load_demo": ("singlebehaviorlab.demo", "load_demo"),
     "DEMOS": ("singlebehaviorlab.demo", "DEMOS"),
 }
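The `_PUBLIC_API` mapping pairs each public name with the backend module that defines it; per the module docstring, the backend is imported only on first attribute access. The hook that consumes this mapping is not part of the diff; below is a minimal, hypothetical sketch of how such lazy resolution is typically done with a module-level `__getattr__` (PEP 562), using a one-entry stand-in mapping:

```python
# Hypothetical sketch of lazy attribute resolution (PEP 562); the real hook
# in singlebehaviorlab/__init__.py is not shown in this diff.
import importlib

_PUBLIC_API = {
    "learn_behavior_features": ("singlebehaviorlab.backend.contrastive", "learn_behavior_features"),
}

def __getattr__(name: str):
    try:
        module_name, attr = _PUBLIC_API[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    module = importlib.import_module(module_name)  # backend module loads on first access
    return getattr(module, attr)
```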
{singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/clustering.py
@@ -28,7 +28,8 @@ class ClusteringParams:
     n_components: int = 2
     n_neighbors: int = 15
     min_dist: float = 0.1
-    normalization: str = "standard"  # standard | minmax | l2 | none
+    normalization: str = "standard"
+    subtract_video_mean: bool = False
     leiden_resolution: float = 1.0
     leiden_k: int = 15
     min_cluster_size: int = 10
@@ -157,8 +158,29 @@ def run_clustering(
     matrix_df, metadata_df = _load_matrix_metadata(matrix_path_str, metadata_path_str)
     _log(f"Matrix shape: {matrix_df.shape[0]} features × {matrix_df.shape[1]} samples")
 
-    processed = _normalize(matrix_df.T, params.normalization)
-    _log(f"Processed shape: {processed.shape} (samples × features); normalization={params.normalization}")
+    X = matrix_df.T
+    X = X.replace([np.inf, -np.inf], np.nan).fillna(0.0)
+
+    if params.subtract_video_mean and metadata_df is not None:
+        group_col = None
+        for col in ("group", "video_id"):
+            if col in metadata_df.columns:
+                group_col = col
+                break
+        snippet_col = "snippet" if "snippet" in metadata_df.columns else None
+        if group_col and snippet_col:
+            for grp in metadata_df[group_col].unique():
+                grp_snippets = metadata_df.loc[metadata_df[group_col] == grp, snippet_col].values
+                mask = X.index.isin(grp_snippets)
+                if mask.sum() > 1:
+                    X.loc[mask] -= X.loc[mask].mean(axis=0)
+            _log("Applied per-video mean subtraction")
+
+
+    processed = _normalize(X, params.normalization)
+
+
+    _log(f"Processed shape: {processed.shape} (samples × features)")
 
     _log(
         f"Running UMAP (n_neighbors={params.n_neighbors}, "
singlebehaviorlab-2.3.3/singlebehaviorlab/backend/contrastive.py
@@ -0,0 +1,202 @@
+"""Temporal contrastive projection for behavior-focused embeddings.
+
+Trains a lightweight MLP on pre-computed VideoPrism embeddings using
+temporal proximity as the supervision signal: clips close in time within
+the same video should map nearby; clips far apart should map far away.
+The projected embeddings suppress static visual factors (lighting,
+background, camera) and amplify behavioral dynamics.
+"""
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["learn_behavior_features"]
+
+_DEFAULT_DIM = 128
+_DEFAULT_EPOCHS = 30
+_DEFAULT_LR = 3e-4
+_POSITIVE_WINDOW = 5
+_TEMPERATURE = 0.07
+
+
+class _ProjectionHead(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int):
+        super().__init__()
+        hidden = max(out_dim, in_dim // 2)
+        self.net = nn.Sequential(
+            nn.Linear(in_dim, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, out_dim),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.normalize(self.net(x), dim=-1)
+
+
+def _build_pairs(
+    metadata: pd.DataFrame,
+    n_samples: int,
+    positive_window: int,
+    rng: np.random.Generator,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    group_col = None
+    for col in ("group", "video_id"):
+        if col in metadata.columns:
+            group_col = col
+            break
+    snippet_col = "snippet" if "snippet" in metadata.columns else None
+    if not group_col or not snippet_col:
+        indices = np.arange(len(metadata))
+        rng.shuffle(indices)
+        anchors = indices[:n_samples]
+        positives = np.clip(anchors + rng.integers(-positive_window, positive_window + 1, size=n_samples), 0, len(metadata) - 1)
+        negatives = rng.integers(0, len(metadata), size=n_samples)
+        return anchors, positives, negatives
+
+    groups = metadata[group_col].values
+    unique_groups = np.unique(groups)
+    group_indices: dict[Any, np.ndarray] = {}
+    for g in unique_groups:
+        group_indices[g] = np.where(groups == g)[0]
+
+    anchors = []
+    positives = []
+    negatives = []
+    per_group = max(1, n_samples // len(unique_groups))
+
+    for g in unique_groups:
+        idx = group_indices[g]
+        if len(idx) < 2:
+            continue
+        a = rng.choice(idx, size=min(per_group, len(idx)), replace=len(idx) < per_group)
+        for ai in a:
+            pos_in_group = np.where(idx == ai)[0][0]
+            lo = max(0, pos_in_group - positive_window)
+            hi = min(len(idx), pos_in_group + positive_window + 1)
+            candidates = idx[lo:hi]
+            candidates = candidates[candidates != ai]
+            if len(candidates) == 0:
+                continue
+            pi = rng.choice(candidates)
+
+            other_groups = [og for og in unique_groups if og != g]
+            if other_groups:
+                ng = rng.choice(other_groups)
+                ni = rng.choice(group_indices[ng])
+            else:
+                far_lo = max(0, pos_in_group - 3 * positive_window)
+                far_hi = min(len(idx), pos_in_group + 3 * positive_window + 1)
+                far_candidates = np.setdiff1d(idx, idx[far_lo:far_hi])
+                if len(far_candidates) == 0:
+                    far_candidates = idx
+                ni = rng.choice(far_candidates)
+
+            anchors.append(ai)
+            positives.append(pi)
+            negatives.append(ni)
+
+    return np.array(anchors), np.array(positives), np.array(negatives)
+
+
+def _info_nce_loss(anchor, positive, negative, temperature):
+    pos_sim = (anchor * positive).sum(dim=-1) / temperature
+    neg_sim = (anchor * negative).sum(dim=-1) / temperature
+    logits = torch.stack([pos_sim, neg_sim], dim=-1)
+    labels = torch.zeros(len(anchor), dtype=torch.long, device=anchor.device)
+    return F.cross_entropy(logits, labels)
+
+
+def learn_behavior_features(
+    matrix_path: str | os.PathLike[str],
+    output_path: str | os.PathLike[str],
+    *,
+    metadata_path: Optional[str | os.PathLike[str]] = None,
+    out_dim: int = _DEFAULT_DIM,
+    epochs: int = _DEFAULT_EPOCHS,
+    lr: float = _DEFAULT_LR,
+    positive_window: int = _POSITIVE_WINDOW,
+    temperature: float = _TEMPERATURE,
+    log_fn: Optional[Callable[[str], None]] = None,
+) -> dict[str, str]:
+    """Train a contrastive projection and write the projected embedding matrix.
+
+    Returns dict with ``matrix`` and ``metadata`` output paths.
+    """
+    from singlebehaviorlab.backend.clustering import _load_matrix_metadata
+
+    matrix_path = str(Path(matrix_path).expanduser().resolve())
+    output_path_obj = Path(output_path).expanduser().resolve()
+    output_path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+    metadata_path_str = str(Path(metadata_path).expanduser().resolve()) if metadata_path else None
+
+    def _log(msg: str) -> None:
+        if log_fn:
+            log_fn(msg)
+
+    matrix_df, metadata_df = _load_matrix_metadata(matrix_path, metadata_path_str)
+    X = matrix_df.T
+    embeddings = X.values.astype(np.float32)
+    n_samples, in_dim = embeddings.shape
+    _log(f"Loaded {n_samples} embeddings ({in_dim}-dim)")
+
+    if metadata_df is None:
+        metadata_df = pd.DataFrame({"snippet": X.index, "group": "video_0"})
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = _ProjectionHead(in_dim, out_dim).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    all_emb = torch.from_numpy(embeddings).to(device)
+
+    rng = np.random.default_rng(42)
+    pairs_per_epoch = max(1024, min(n_samples * 4, 65536))
+
+    _log(f"Training projection head ({in_dim} → {out_dim}) for {epochs} epochs on {device}")
+    for epoch in range(epochs):
+        anchors, positives, negatives = _build_pairs(metadata_df, pairs_per_epoch, positive_window, rng)
+        if len(anchors) == 0:
+            _log("No valid pairs found — check metadata has group/video_id column")
+            break
+        a_emb = model(all_emb[anchors])
+        p_emb = model(all_emb[positives])
+        n_emb = model(all_emb[negatives])
+        loss = _info_nce_loss(a_emb, p_emb, n_emb, temperature)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        if (epoch + 1) % 10 == 0 or epoch == 0:
+            _log(f"  epoch {epoch + 1}/{epochs} loss={loss.item():.4f}")
+
+    model.eval()
+    with torch.no_grad():
+        projected = model(all_emb).cpu().numpy()
+    _log(f"Projected embeddings: {projected.shape}")
+
+    snippet_ids = np.array(X.index.tolist())
+    feature_names = np.array([f"behavior_feat_{i}" for i in range(out_dim)])
+
+    out_matrix = str(output_path_obj)
+    if out_matrix.endswith("_matrix.npz"):
+        out_metadata = out_matrix.replace("_matrix.npz", "_metadata.npz")
+    elif out_matrix.endswith(".npz"):
+        out_metadata = out_matrix[:-4] + "_metadata.npz"
+    else:
+        out_metadata = out_matrix + "_metadata.npz"
+
+    np.savez_compressed(out_matrix, matrix=projected.T, feature_names=feature_names, snippet_ids=snippet_ids)
+    _log(f"Wrote projected matrix: {out_matrix}")
+
+    if metadata_df is not None:
+        np.savez_compressed(out_metadata, metadata=metadata_df.values, columns=np.array(metadata_df.columns))
+        _log(f"Wrote metadata: {out_metadata}")
+
+    return {"matrix": out_matrix, "metadata": out_metadata}
singlebehaviorlab-2.3.3/singlebehaviorlab/backend/embedding_refine.py
@@ -0,0 +1,158 @@
+"""Embedding-based timeline refinement.
+
+Uses per-frame embeddings from the inference model to correct predictions
+via semi-supervised label propagation on a nearest-neighbor graph, then
+detects true behavior boundaries from embedding distance spikes.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+from typing import Optional
+
+__all__ = ["refine_with_embeddings"]
+
+
+def _cosine_distance_adjacent(embeddings: np.ndarray) -> np.ndarray:
+    norms = np.maximum(np.linalg.norm(embeddings, axis=1, keepdims=True), 1e-8)
+    normed = embeddings / norms
+    return 1.0 - np.sum(normed[:-1] * normed[1:], axis=1)
+
+
+def _detect_boundaries(distances: np.ndarray, threshold_factor: float) -> list[int]:
+    if len(distances) == 0:
+        return []
+    median = float(np.median(distances))
+    mad = float(np.median(np.abs(distances - median)))
+    threshold = median + threshold_factor * max(mad, 1e-6)
+    boundaries = [0]
+    for i, d in enumerate(distances):
+        if d > threshold:
+            boundaries.append(i + 1)
+    return boundaries
+
+
+def _majority_label(labels: np.ndarray, weights: Optional[np.ndarray] = None) -> int:
+    valid_mask = labels >= 0
+    valid = labels[valid_mask]
+    if len(valid) == 0:
+        return -1
+    if weights is not None:
+        w = weights[valid_mask]
+        counts: dict[int, float] = {}
+        for lbl, wt in zip(valid, w):
+            counts[int(lbl)] = counts.get(int(lbl), 0.0) + float(wt)
+        return max(counts, key=counts.get)
+    vals, cnts = np.unique(valid, return_counts=True)
+    return int(vals[np.argmax(cnts)])
+
+
+def _label_propagation_correction(
+    frame_labels: np.ndarray,
+    frame_embeddings: np.ndarray,
+    frame_confidences: np.ndarray,
+    confidence_threshold: float,
+) -> np.ndarray:
+    from sklearn.semi_supervised import LabelSpreading
+
+    n_frames = len(frame_labels)
+    labels_for_propagation = frame_labels.copy()
+
+    for i in range(n_frames):
+        if frame_confidences[i] < confidence_threshold:
+            labels_for_propagation[i] = -1
+
+    n_labeled = np.sum(labels_for_propagation >= 0)
+    if n_labeled < 2 or n_labeled == n_frames:
+        return frame_labels.copy()
+
+    n_neighbors = min(7, n_frames - 1)
+    lp = LabelSpreading(kernel="knn", n_neighbors=n_neighbors, max_iter=30, alpha=0.2)
+    lp.fit(frame_embeddings, labels_for_propagation)
+    propagated = lp.transduction_
+
+    result = frame_labels.copy()
+    for i in range(n_frames):
+        if frame_confidences[i] < confidence_threshold and propagated[i] >= 0:
+            result[i] = int(propagated[i])
+
+    return result
+
+
+def refine_with_embeddings(
+    frame_labels: np.ndarray,
+    frame_embeddings: np.ndarray,
+    frame_confidences: Optional[np.ndarray] = None,
+    n_classes: int = 0,
+    boundary_sensitivity: float = 1.5,
+    min_segment_frames: int = 3,
+    confidence_threshold: float = 0.7,
+) -> np.ndarray:
+    """Refine per-frame predictions using label propagation and boundary detection.
+
+    High-confidence predictions seed a nearest-neighbor graph. Labels
+    propagate to uncertain frames through embedding similarity. Boundary
+    detection then snaps segment edges to real embedding transitions.
+    """
+    n_frames = len(frame_labels)
+    if n_frames < 4 or frame_embeddings.shape[0] != n_frames:
+        return frame_labels.copy()
+
+    if frame_confidences is None:
+        frame_confidences = np.ones(n_frames, dtype=np.float32)
+
+    corrected = _label_propagation_correction(
+        frame_labels, frame_embeddings, frame_confidences, confidence_threshold,
+    )
+
+    distances = _cosine_distance_adjacent(frame_embeddings)
+    boundaries = _detect_boundaries(distances, boundary_sensitivity)
+    boundaries.append(n_frames)
+
+    refined = corrected.copy()
+    segments: list[tuple[int, int]] = []
+    for i in range(len(boundaries) - 1):
+        start, end = boundaries[i], boundaries[i + 1]
+        if end <= start:
+            continue
+        majority = _majority_label(corrected[start:end], frame_confidences[start:end])
+        refined[start:end] = majority
+        segments.append((start, end))
+
+    changed = True
+    while changed:
+        changed = False
+        new_segments = []
+        i = 0
+        while i < len(segments):
+            start, end = segments[i]
+            if (end - start) < min_segment_frames and len(segments) > 1:
+                mean_emb = frame_embeddings[start:end].mean(axis=0)
+                mean_emb /= max(np.linalg.norm(mean_emb), 1e-8)
+                best_sim, merge_with = -1.0, -1
+                for j in [i - 1, i + 1]:
+                    if 0 <= j < len(segments):
+                        ns, ne = segments[j]
+                        ne_emb = frame_embeddings[ns:ne].mean(axis=0)
+                        ne_emb /= max(np.linalg.norm(ne_emb), 1e-8)
+                        sim = float(np.dot(mean_emb, ne_emb))
+                        if sim > best_sim:
+                            best_sim, merge_with = sim, j
+                if merge_with >= 0:
+                    ms, me = segments[merge_with]
+                    ms2, me2 = min(start, ms), max(end, me)
+                    majority = _majority_label(corrected[ms2:me2], frame_confidences[ms2:me2])
+                    refined[ms2:me2] = majority
+                    if merge_with < i:
+                        new_segments[-1] = (ms2, me2)
+                    else:
+                        new_segments.append((ms2, me2))
+                        i += 1
+                    changed = True
+                    i += 1
+                    continue
+            new_segments.append((start, end))
+            i += 1
+        segments = new_segments
+
+    return refined
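Boundary detection is a robust-threshold rule: a boundary is declared wherever the adjacent-frame cosine distance exceeds median + sensitivity × MAD. A synthetic smoke test for the whole pass, with a fabricated two-regime embedding sequence and a deliberately misplaced label boundary (all data hypothetical):

```python
# Synthetic smoke test for the new refinement pass (fabricated data).
import numpy as np
from singlebehaviorlab.backend.embedding_refine import refine_with_embeddings

rng = np.random.default_rng(0)
n = 100
# Two embedding "regimes" with a sharp transition at frame 50.
emb = np.concatenate([rng.normal(0, 0.1, (50, 16)), rng.normal(1, 0.1, (50, 16))])
labels = np.zeros(n, dtype=int)
labels[48:] = 1                      # predicted boundary 2 frames early
conf = np.full(n, 0.9, dtype=np.float32)
conf[45:55] = 0.3                    # model uncertain around the transition

refined = refine_with_embeddings(labels, emb, conf)
# Low-confidence frames are re-labeled from the embedding neighborhood and
# the segment edge snaps to the actual distance spike near frame 50.
print(refined[45:55])
```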
{singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/backend/registration.py
@@ -40,6 +40,7 @@ class RegistrationParams:
     clip_length_frames: int = 16
     step_frames: Optional[int] = None
     backbone_model: str = "videoprism_public_v1_base"
+    flip_invariant: bool = False
     experiment_name: Optional[str] = None
 
     @property
@@ -71,6 +72,7 @@ def _extract_embedding(
     backbone: VideoPrismBackbone,
     frames: np.ndarray,
     target_size: int,
+    flip_invariant: bool = False,
 ) -> Optional[np.ndarray]:
     try:
         resized = np.array([cv2.resize(f, (target_size, target_size)) for f in frames])
@@ -79,6 +81,12 @@ def _extract_embedding(
         with torch.no_grad():
             tokens = backbone(tensor)
             embedding = tokens.mean(dim=1).squeeze(0).cpu().numpy()
+            if flip_invariant:
+                embs = [embedding]
+                for dims in [[-1], [-2], [-1, -2]]:
+                    t_flip = torch.flip(tensor, dims=dims)
+                    embs.append(backbone(t_flip).mean(dim=1).squeeze(0).cpu().numpy())
+                embedding = np.mean(embs, axis=0)
         return embedding.astype(np.float32)
     except Exception:
         return None
@@ -178,7 +186,7 @@ def run_registration(
         if frames is None or len(frames) == 0:
             _log(f"Skipping {os.path.basename(clip_path)}: no frames")
             continue
-        embedding = _extract_embedding(backbone, frames, params.target_size)
+        embedding = _extract_embedding(backbone, frames, params.target_size, params.flip_invariant)
         del frames
         if embedding is None:
             _log(f"Skipping {os.path.basename(clip_path)}: embedding failed")
{singlebehaviorlab-2.3.1 → singlebehaviorlab-2.3.3}/singlebehaviorlab/cli.py
@@ -178,6 +178,10 @@ def _build_parser() -> argparse.ArgumentParser:
         "--no-clahe", dest="clahe", action="store_false", default=None,
         help="Disable CLAHE contrast normalization.",
     )
+    register_parser.add_argument(
+        "--flip-invariant", action="store_true",
+        help="Average original + horizontally flipped embeddings to remove facing-direction bias. 2x extraction time.",
+    )
     _add_common_runtime_flags(register_parser)
 
     segment_parser = subparsers.add_parser(
@@ -347,6 +351,8 @@ def cmd_register(args: argparse.Namespace) -> int:
     params.target_fps = int(args.target_fps)
     if args.clahe is False:
         params.normalization_method = "None"
+    if args.flip_invariant:
+        params.flip_invariant = True
 
     bar = {"pbar": None}
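Assuming the subcommand wired to `cmd_register` is named `register` (the parser variable is `register_parser`), the new flag would be passed as `singlebehaviorlab register --flip-invariant` alongside the existing registration arguments (omitted here); it does nothing beyond setting `RegistrationParams.flip_invariant = True` before the session starts.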