foodforthought-cli 0.2.7__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. ate/__init__.py +6 -0
  2. ate/__main__.py +16 -0
  3. ate/auth/__init__.py +1 -0
  4. ate/auth/device_flow.py +141 -0
  5. ate/auth/token_store.py +96 -0
  6. ate/behaviors/__init__.py +100 -0
  7. ate/behaviors/approach.py +399 -0
  8. ate/behaviors/common.py +686 -0
  9. ate/behaviors/tree.py +454 -0
  10. ate/cli.py +855 -3995
  11. ate/client.py +90 -0
  12. ate/commands/__init__.py +168 -0
  13. ate/commands/auth.py +389 -0
  14. ate/commands/bridge.py +448 -0
  15. ate/commands/data.py +185 -0
  16. ate/commands/deps.py +111 -0
  17. ate/commands/generate.py +384 -0
  18. ate/commands/memory.py +907 -0
  19. ate/commands/parts.py +166 -0
  20. ate/commands/primitive.py +399 -0
  21. ate/commands/protocol.py +288 -0
  22. ate/commands/recording.py +524 -0
  23. ate/commands/repo.py +154 -0
  24. ate/commands/simulation.py +291 -0
  25. ate/commands/skill.py +303 -0
  26. ate/commands/skills.py +487 -0
  27. ate/commands/team.py +147 -0
  28. ate/commands/workflow.py +271 -0
  29. ate/detection/__init__.py +38 -0
  30. ate/detection/base.py +142 -0
  31. ate/detection/color_detector.py +399 -0
  32. ate/detection/trash_detector.py +322 -0
  33. ate/drivers/__init__.py +39 -0
  34. ate/drivers/ble_transport.py +405 -0
  35. ate/drivers/mechdog.py +942 -0
  36. ate/drivers/wifi_camera.py +477 -0
  37. ate/interfaces/__init__.py +187 -0
  38. ate/interfaces/base.py +273 -0
  39. ate/interfaces/body.py +267 -0
  40. ate/interfaces/detection.py +282 -0
  41. ate/interfaces/locomotion.py +422 -0
  42. ate/interfaces/manipulation.py +408 -0
  43. ate/interfaces/navigation.py +389 -0
  44. ate/interfaces/perception.py +362 -0
  45. ate/interfaces/sensors.py +247 -0
  46. ate/interfaces/types.py +371 -0
  47. ate/llm_proxy.py +239 -0
  48. ate/mcp_server.py +387 -0
  49. ate/memory/__init__.py +35 -0
  50. ate/memory/cloud.py +244 -0
  51. ate/memory/context.py +269 -0
  52. ate/memory/embeddings.py +184 -0
  53. ate/memory/export.py +26 -0
  54. ate/memory/merge.py +146 -0
  55. ate/memory/migrate/__init__.py +34 -0
  56. ate/memory/migrate/base.py +89 -0
  57. ate/memory/migrate/pipeline.py +189 -0
  58. ate/memory/migrate/sources/__init__.py +13 -0
  59. ate/memory/migrate/sources/chroma.py +170 -0
  60. ate/memory/migrate/sources/pinecone.py +120 -0
  61. ate/memory/migrate/sources/qdrant.py +110 -0
  62. ate/memory/migrate/sources/weaviate.py +160 -0
  63. ate/memory/reranker.py +353 -0
  64. ate/memory/search.py +26 -0
  65. ate/memory/store.py +548 -0
  66. ate/recording/__init__.py +83 -0
  67. ate/recording/demonstration.py +378 -0
  68. ate/recording/session.py +415 -0
  69. ate/recording/upload.py +304 -0
  70. ate/recording/visual.py +416 -0
  71. ate/recording/wrapper.py +95 -0
  72. ate/robot/__init__.py +221 -0
  73. ate/robot/agentic_servo.py +856 -0
  74. ate/robot/behaviors.py +493 -0
  75. ate/robot/ble_capture.py +1000 -0
  76. ate/robot/ble_enumerate.py +506 -0
  77. ate/robot/calibration.py +668 -0
  78. ate/robot/calibration_state.py +388 -0
  79. ate/robot/commands.py +3735 -0
  80. ate/robot/direction_calibration.py +554 -0
  81. ate/robot/discovery.py +441 -0
  82. ate/robot/introspection.py +330 -0
  83. ate/robot/llm_system_id.py +654 -0
  84. ate/robot/locomotion_calibration.py +508 -0
  85. ate/robot/manager.py +270 -0
  86. ate/robot/marker_generator.py +611 -0
  87. ate/robot/perception.py +502 -0
  88. ate/robot/primitives.py +614 -0
  89. ate/robot/profiles.py +281 -0
  90. ate/robot/registry.py +322 -0
  91. ate/robot/servo_mapper.py +1153 -0
  92. ate/robot/skill_upload.py +675 -0
  93. ate/robot/target_calibration.py +500 -0
  94. ate/robot/teach.py +515 -0
  95. ate/robot/types.py +242 -0
  96. ate/robot/visual_labeler.py +1048 -0
  97. ate/robot/visual_servo_loop.py +494 -0
  98. ate/robot/visual_servoing.py +570 -0
  99. ate/robot/visual_system_id.py +906 -0
  100. ate/transports/__init__.py +121 -0
  101. ate/transports/base.py +394 -0
  102. ate/transports/ble.py +405 -0
  103. ate/transports/hybrid.py +444 -0
  104. ate/transports/serial.py +345 -0
  105. ate/urdf/__init__.py +30 -0
  106. ate/urdf/capture.py +582 -0
  107. ate/urdf/cloud.py +491 -0
  108. ate/urdf/collision.py +271 -0
  109. ate/urdf/commands.py +708 -0
  110. ate/urdf/depth.py +360 -0
  111. ate/urdf/inertial.py +312 -0
  112. ate/urdf/kinematics.py +330 -0
  113. ate/urdf/lifting.py +415 -0
  114. ate/urdf/meshing.py +300 -0
  115. ate/urdf/models/__init__.py +110 -0
  116. ate/urdf/models/depth_anything.py +253 -0
  117. ate/urdf/models/sam2.py +324 -0
  118. ate/urdf/motion_analysis.py +396 -0
  119. ate/urdf/pipeline.py +468 -0
  120. ate/urdf/scale.py +256 -0
  121. ate/urdf/scan_session.py +411 -0
  122. ate/urdf/segmentation.py +299 -0
  123. ate/urdf/synthesis.py +319 -0
  124. ate/urdf/topology.py +336 -0
  125. ate/urdf/validation.py +371 -0
  126. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/METADATA +9 -1
  127. foodforthought_cli-0.3.0.dist-info/RECORD +166 -0
  128. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/WHEEL +1 -1
  129. foodforthought_cli-0.2.7.dist-info/RECORD +0 -44
  130. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/entry_points.txt +0 -0
  131. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/top_level.txt +0 -0
ate/urdf/models/__init__.py
@@ -0,0 +1,110 @@
+ """
+ Model loading and caching for URDF scan pipeline.
+
+ Provides lazy loading and caching for large ML models:
+ - SAM 2 (Segment Anything Model 2) for temporal segmentation
+ - Depth Anything V2 for metric depth estimation
+ """
+
+ from typing import Optional
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelCache:
+     """
+     Singleton cache for expensive ML models.
+
+     Models are loaded lazily on first access and cached for reuse.
+     This prevents redundant loading when processing multiple sessions.
+     """
+
+     _sam2_model = None
+     _sam2_predictor = None
+     _depth_model = None
+     _device: str = "cpu"
+
+     @classmethod
+     def set_device(cls, device: str) -> None:
+         """Set the device for model inference (cuda or cpu)."""
+         cls._device = device
+         # Clear cached models if device changes
+         if cls._sam2_model is not None or cls._depth_model is not None:
+             logger.info(f"Device changed to {device}, clearing model cache")
+             cls.clear()
+
+     @classmethod
+     def get_device(cls) -> str:
+         """Get the current device setting."""
+         return cls._device
+
+     @classmethod
+     def get_sam2(cls):
+         """
+         Get SAM 2 model, loading if necessary.
+
+         Returns:
+             SAM 2 predictor ready for video inference
+         """
+         if cls._sam2_predictor is None:
+             from .sam2 import load_sam2_predictor
+             logger.info(f"Loading SAM 2 model on {cls._device}...")
+             cls._sam2_predictor = load_sam2_predictor(cls._device)
+             logger.info("SAM 2 model loaded")
+         return cls._sam2_predictor
+
+     @classmethod
+     def get_depth_model(cls):
+         """
+         Get Depth Anything V2 model, loading if necessary.
+
+         Returns:
+             Depth Anything V2 model ready for inference
+         """
+         if cls._depth_model is None:
+             from .depth_anything import load_depth_model
+             logger.info(f"Loading Depth Anything V2 on {cls._device}...")
+             cls._depth_model = load_depth_model(cls._device)
+             logger.info("Depth Anything V2 loaded")
+         return cls._depth_model
+
+     @classmethod
+     def clear(cls) -> None:
+         """Clear all cached models to free memory."""
+         cls._sam2_model = None
+         cls._sam2_predictor = None
+         cls._depth_model = None
+         logger.info("Model cache cleared")
+
+     @classmethod
+     def is_loaded(cls, model_name: str) -> bool:
+         """Check if a specific model is loaded."""
+         if model_name == "sam2":
+             return cls._sam2_predictor is not None
+         elif model_name == "depth":
+             return cls._depth_model is not None
+         return False
+
+
+ def check_cuda_available() -> bool:
+     """Check if CUDA is available for GPU acceleration."""
+     try:
+         import torch
+         return torch.cuda.is_available()
+     except ImportError:
+         return False
+
+
+ def get_recommended_device() -> str:
+     """Get the recommended device based on available hardware."""
+     if check_cuda_available():
+         return "cuda"
+     return "cpu"
+
+
+ __all__ = [
+     "ModelCache",
+     "check_cuda_available",
+     "get_recommended_device",
+ ]
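
The cache above is meant to be used as a process-wide singleton: pick a device once, then fetch models on demand. The following sketch is illustrative only and is not part of the package diff; it assumes the wheel is installed and importable as `ate`.

    from ate.urdf.models import ModelCache, get_recommended_device

    # Prefer CUDA when available, otherwise stay on CPU.
    ModelCache.set_device(get_recommended_device())

    # First access loads and caches; later calls reuse the same instances.
    sam2 = ModelCache.get_sam2()
    depth = ModelCache.get_depth_model()
    print(ModelCache.is_loaded("sam2"), ModelCache.is_loaded("depth"))

    # Release model memory between scan sessions.
    ModelCache.clear()
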
ate/urdf/models/depth_anything.py
@@ -0,0 +1,253 @@
+ """
+ Depth Anything V2 wrapper for metric depth estimation.
+
+ This module provides:
+ - load_depth_model: Load the Depth Anything V2 model
+ - DepthEstimator: Wrapper for depth inference
+
+ When Depth Anything V2 is not installed, a mock implementation
+ is used that generates synthetic depth maps for testing.
+
+ Installation:
+     pip install depth-anything-v2
+
+ References:
+     https://github.com/DepthAnything/Depth-Anything-V2
+ """
+
+ import logging
+ from typing import Optional, Tuple
+ from pathlib import Path
+
+ logger = logging.getLogger(__name__)
+
+ try:
+     import numpy as np
+     NUMPY_AVAILABLE = True
+ except ImportError:
+     NUMPY_AVAILABLE = False
+     np = None
+
+ try:
+     import torch
+     import torch.nn.functional as F
+     from depth_anything_v2.dpt import DepthAnythingV2
+     DEPTH_ANYTHING_AVAILABLE = True
+ except ImportError:
+     DEPTH_ANYTHING_AVAILABLE = False
+     torch = None
+
+
+ class DepthEstimator:
+     """
+     Wrapper for Depth Anything V2 model.
+
+     Provides metric depth estimation from single RGB images.
+     """
+
+     def __init__(self, model, device: str = "cpu"):
+         """
+         Initialize estimator wrapper.
+
+         Args:
+             model: Depth Anything V2 model instance
+             device: Compute device
+         """
+         self.model = model
+         self.device = device
+
+     def estimate(self, image: "np.ndarray") -> "np.ndarray":
+         """
+         Estimate depth from an RGB image.
+
+         Args:
+             image: BGR image from OpenCV (H, W, 3)
+
+         Returns:
+             Depth map (H, W) in relative units
+         """
+         import cv2
+
+         # Convert BGR to RGB
+         rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         # Prepare input
+         h, w = rgb.shape[:2]
+
+         # Resize to model input size (typically 518x518 for ViT)
+         input_size = 518
+         rgb_resized = cv2.resize(rgb, (input_size, input_size))
+
+         # Normalize
+         rgb_tensor = torch.from_numpy(rgb_resized).float().permute(2, 0, 1)
+         rgb_tensor = rgb_tensor / 255.0
+         rgb_tensor = rgb_tensor.unsqueeze(0).to(self.device)
+
+         # Inference
+         with torch.no_grad():
+             depth = self.model(rgb_tensor)
+
+         # Resize back to original
+         depth = F.interpolate(
+             depth.unsqueeze(1),
+             size=(h, w),
+             mode="bilinear",
+             align_corners=False,
+         )
+
+         depth_np = depth.squeeze().cpu().numpy()
+         return depth_np
+
+     def estimate_metric(
+         self,
+         image: "np.ndarray",
+         scale_factor: float = 1.0,
+     ) -> "np.ndarray":
+         """
+         Estimate metric depth with scale correction.
+
+         Args:
+             image: BGR image from OpenCV
+             scale_factor: Scale factor to convert to meters
+
+         Returns:
+             Depth map (H, W) in meters
+         """
+         depth = self.estimate(image)
+         return depth * scale_factor
+
+
+ class MockDepthEstimator:
+     """
+     Mock depth estimator for testing without the real model.
+
+     Generates synthetic depth maps based on image brightness,
+     simulating closer objects being brighter (like typical indoor lighting).
+     """
+
+     def __init__(self, device: str = "cpu"):
+         """Initialize mock estimator."""
+         self.device = device
+
+     def estimate(self, image: "np.ndarray") -> "np.ndarray":
+         """
+         Generate mock depth from image brightness.
+
+         Args:
+             image: BGR image
+
+         Returns:
+             Mock depth map
+         """
+         import cv2
+
+         # Convert to grayscale
+         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32)
+
+         # Invert: brighter = closer = smaller depth
+         depth = 255.0 - gray
+
+         # Normalize to [0.5, 5.0] meters range (typical indoor scene)
+         depth = 0.5 + (depth / 255.0) * 4.5
+
+         # Add some noise for realism
+         noise = np.random.normal(0, 0.1, depth.shape).astype(np.float32)
+         depth = depth + noise
+         depth = np.clip(depth, 0.3, 6.0)
+
+         return depth
+
+     def estimate_metric(
+         self,
+         image: "np.ndarray",
+         scale_factor: float = 1.0,
+     ) -> "np.ndarray":
+         """Generate mock metric depth."""
+         depth = self.estimate(image)
+         # Scale factor adjusts the overall scale
+         return depth * scale_factor
+
+
+ def load_depth_model(device: str = "cpu") -> DepthEstimator:
+     """
+     Load Depth Anything V2 model.
+
+     Args:
+         device: Compute device ("cuda" or "cpu")
+
+     Returns:
+         DepthEstimator instance (or mock if unavailable)
+     """
+     if DEPTH_ANYTHING_AVAILABLE:
+         try:
+             logger.info("Loading Depth Anything V2 model...")
+
+             # Model configurations
+             model_configs = {
+                 "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
+                 "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
+                 "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
+             }
+
+             # Try to find checkpoint
+             checkpoint_paths = [
+                 Path.home() / ".cache" / "depth_anything_v2" / "depth_anything_v2_vitl.pth",
+                 Path.home() / ".cache" / "depth_anything_v2" / "depth_anything_v2_vitb.pth",
+                 Path.home() / ".cache" / "depth_anything_v2" / "depth_anything_v2_vits.pth",
+                 Path("./checkpoints/depth_anything_v2_vitb.pth"),
+             ]
+
+             checkpoint = None
+             config_name = None
+             for path in checkpoint_paths:
+                 if path.exists():
+                     checkpoint = str(path)
+                     # Infer config from filename
+                     if "vitl" in path.name:
+                         config_name = "vitl"
+                     elif "vitb" in path.name:
+                         config_name = "vitb"
+                     else:
+                         config_name = "vits"
+                     break
+
+             if checkpoint is None:
+                 logger.warning(
+                     "Depth Anything V2 checkpoint not found. Using mock estimator. "
+                     "Download from: https://github.com/DepthAnything/Depth-Anything-V2"
+                 )
+                 return MockDepthEstimator(device)
+
+             # Build model
+             config = model_configs[config_name]
+             model = DepthAnythingV2(**config)
+             model.load_state_dict(torch.load(checkpoint, map_location=device))
+             model.to(device)
+             model.eval()
+
+             logger.info(f"Loaded Depth Anything V2 ({config_name}) on {device}")
+             return DepthEstimator(model, device)
+
+         except Exception as e:
+             logger.warning(f"Failed to load Depth Anything V2: {e}. Using mock estimator.")
+             return MockDepthEstimator(device)
+
+     else:
+         logger.warning(
+             "Depth Anything V2 not installed. Using mock estimator. "
+             "Install with: pip install depth-anything-v2"
+         )
+         return MockDepthEstimator(device)
+
+
+ def check_depth_anything_available() -> bool:
+     """Check if Depth Anything V2 is available."""
+     return DEPTH_ANYTHING_AVAILABLE
+
+
+ __all__ = [
+     "DepthEstimator",
+     "MockDepthEstimator",
+     "load_depth_model",
+     "check_depth_anything_available",
+ ]
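
The intended call pattern for the depth wrapper is load once, then estimate per frame. The sketch below is not part of the diff; the frame path and scale factor are placeholder values, and without an installed checkpoint load_depth_model falls back to MockDepthEstimator.

    import cv2
    from ate.urdf.models.depth_anything import (
        check_depth_anything_available,
        load_depth_model,
    )

    estimator = load_depth_model(device="cpu")       # DepthEstimator or MockDepthEstimator
    frame = cv2.imread("frame_000000.jpg")           # BGR image, as the wrapper expects
    relative = estimator.estimate(frame)             # (H, W) relative depth
    metric = estimator.estimate_metric(frame, scale_factor=0.85)  # scaled toward meters
    print(check_depth_anything_available(), metric.shape)
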
ate/urdf/models/sam2.py
@@ -0,0 +1,324 @@
+ """
+ SAM 2 (Segment Anything Model 2) wrapper for temporal segmentation.
+
+ This module provides:
+ - load_sam2_predictor: Load the SAM 2 video predictor
+ - SAM2VideoPredictor: Wrapper for video segmentation
+
+ When SAM 2 is not installed, a mock implementation is used that
+ generates simple masks based on color clustering for testing.
+
+ Installation:
+     pip install segment-anything-2
+
+ References:
+     https://github.com/facebookresearch/segment-anything-2
+ """
+
+ import logging
+ from typing import Dict, List, Optional, Tuple
+ from pathlib import Path
+
+ logger = logging.getLogger(__name__)
+
+ try:
+     import numpy as np
+     NUMPY_AVAILABLE = True
+ except ImportError:
+     NUMPY_AVAILABLE = False
+     np = None
+
+ try:
+     import torch
+     from sam2.build_sam import build_sam2_video_predictor
+     SAM2_AVAILABLE = True
+ except ImportError:
+     SAM2_AVAILABLE = False
+     torch = None
+
+
+ class SAM2VideoPredictor:
+     """
+     Wrapper for SAM 2 video segmentation.
+
+     Provides temporal mask propagation from initial point prompts.
+     """
+
+     def __init__(self, predictor, device: str = "cpu"):
+         """
+         Initialize predictor wrapper.
+
+         Args:
+             predictor: SAM 2 video predictor instance
+             device: Compute device
+         """
+         self.predictor = predictor
+         self.device = device
+         self._state = None
+
+     def initialize(self, video_path: str) -> None:
+         """
+         Initialize predictor with a video.
+
+         Args:
+             video_path: Path to video file
+         """
+         import cv2
+         import tempfile
+         import shutil
+         from pathlib import Path
+
+         # Extract frames using OpenCV (works on Apple Silicon)
+         # SAM 2's init_state can use an image directory instead of video
+         self._temp_dir = tempfile.mkdtemp(prefix="sam2_frames_")
+
+         cap = cv2.VideoCapture(video_path)
+         frame_idx = 0
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             # Save as JPEG with zero-padded frame numbers
+             frame_path = Path(self._temp_dir) / f"{frame_idx:06d}.jpg"
+             cv2.imwrite(str(frame_path), frame)
+             frame_idx += 1
+         cap.release()
+
+         logger.info(f"Extracted {frame_idx} frames to {self._temp_dir}")
+
+         # Use image directory mode instead of video mode
+         self._state = self.predictor.init_state(video_path=self._temp_dir)
+         logger.info(f"SAM 2 initialized with {frame_idx} frames")
+
+     def add_point_prompt(
+         self,
+         frame_idx: int,
+         obj_id: int,
+         point: Tuple[float, float],
+         label: int = 1,
+     ) -> "np.ndarray":
+         """
+         Add a point prompt to initialize tracking.
+
+         Args:
+             frame_idx: Frame index (usually 0)
+             obj_id: Unique object ID for this link
+             point: (x, y) coordinates in frame
+             label: 1 for foreground, 0 for background
+
+         Returns:
+             Initial mask for the object
+         """
+         points = np.array([[point[0], point[1]]], dtype=np.float32)
+         labels = np.array([label], dtype=np.int32)
+
+         _, _, mask_logits = self.predictor.add_new_points_or_box(
+             inference_state=self._state,
+             frame_idx=frame_idx,
+             obj_id=obj_id,
+             points=points,
+             labels=labels,
+         )
+
+         # Convert logits to binary mask
+         mask = (mask_logits > 0).cpu().numpy().squeeze()
+         return mask
+
+     def propagate(self) -> Dict[int, Dict[int, "np.ndarray"]]:
+         """
+         Propagate masks across all video frames.
+
+         Returns:
+             Dict mapping frame_idx -> {obj_id: mask}
+         """
+         results = {}
+
+         for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(
+             self._state
+         ):
+             frame_masks = {}
+             for i, obj_id in enumerate(obj_ids):
+                 mask = (mask_logits[i] > 0).cpu().numpy().squeeze()
+                 frame_masks[obj_id] = mask
+             results[frame_idx] = frame_masks
+
+         logger.info(f"SAM 2 propagated masks across {len(results)} frames")
+         return results
+
+     def reset(self) -> None:
+         """Reset predictor state."""
+         if self._state is not None:
+             self.predictor.reset_state(self._state)
+         self._state = None
+
+
+ class MockSAM2Predictor:
+     """
+     Mock SAM 2 predictor for testing without the real model.
+
+     Uses simple color-based region growing from seed points.
+     This is NOT accurate but allows pipeline testing.
+     """
+
+     def __init__(self, device: str = "cpu"):
+         """Initialize mock predictor."""
+         self.device = device
+         self._video_path: Optional[str] = None
+         self._prompts: Dict[int, Tuple[float, float]] = {}  # obj_id -> point
+         self._frame_count: int = 0
+         self._resolution: Tuple[int, int] = (720, 1280)
+
+     def initialize(self, video_path: str) -> None:
+         """Initialize with video."""
+         import cv2
+
+         self._video_path = video_path
+         cap = cv2.VideoCapture(video_path)
+         self._frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         self._resolution = (height, width)
+         cap.release()
+
+         logger.info(f"Mock SAM 2 initialized: {self._frame_count} frames")
+
+     def add_point_prompt(
+         self,
+         frame_idx: int,
+         obj_id: int,
+         point: Tuple[float, float],
+         label: int = 1,
+     ) -> "np.ndarray":
+         """Add point prompt and return initial mask."""
+         self._prompts[obj_id] = point
+
+         # Generate simple circular mask around point
+         mask = self._generate_circular_mask(point, radius=50)
+         return mask
+
+     def _generate_circular_mask(
+         self,
+         center: Tuple[float, float],
+         radius: int = 50,
+     ) -> "np.ndarray":
+         """Generate a circular mask around a point."""
+         h, w = self._resolution
+         mask = np.zeros((h, w), dtype=bool)
+
+         cx, cy = int(center[0]), int(center[1])
+         y, x = np.ogrid[:h, :w]
+         dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
+         mask[dist <= radius] = True
+
+         return mask
+
+     def propagate(self) -> Dict[int, Dict[int, "np.ndarray"]]:
+         """
+         Generate mock masks for all frames.
+
+         Creates slightly varying masks to simulate motion tracking.
+         """
+         results = {}
+
+         for frame_idx in range(self._frame_count):
+             frame_masks = {}
+
+             for obj_id, point in self._prompts.items():
+                 # Add small random offset to simulate motion
+                 offset_x = np.random.uniform(-5, 5)
+                 offset_y = np.random.uniform(-5, 5)
+                 moved_point = (point[0] + offset_x, point[1] + offset_y)
+
+                 # Vary radius slightly
+                 radius = 50 + np.random.randint(-10, 10)
+                 mask = self._generate_circular_mask(moved_point, radius)
+                 frame_masks[obj_id] = mask
+
+             results[frame_idx] = frame_masks
+
+         logger.info(f"Mock SAM 2 generated masks for {len(results)} frames")
+         return results
+
+     def reset(self) -> None:
+         """Reset state."""
+         self._prompts = {}
+
+
+ def load_sam2_predictor(device: str = "cpu") -> SAM2VideoPredictor:
+     """
+     Load SAM 2 video predictor.
+
+     Args:
+         device: Compute device ("cuda" or "cpu")
+
+     Returns:
+         SAM2VideoPredictor instance (or mock if unavailable)
+     """
+     if SAM2_AVAILABLE:
+         try:
+             # Try to load the real model
+             logger.info("Loading SAM 2 model...")
+
+             # SAM 2 model checkpoint paths
+             # Users should download from:
+             # https://github.com/facebookresearch/segment-anything-2#model-checkpoints
+             checkpoint_paths = [
+                 Path.home() / ".cache" / "sam2" / "sam2_hiera_large.pt",
+                 Path.home() / ".cache" / "sam2" / "sam2_hiera_base_plus.pt",
+                 Path.home() / ".cache" / "sam2" / "sam2_hiera_base.pt",
+                 Path("./checkpoints/sam2_hiera_large.pt"),
+                 Path("./checkpoints/sam2_hiera_base_plus.pt"),
+             ]
+
+             checkpoint = None
+             for path in checkpoint_paths:
+                 if path.exists():
+                     checkpoint = str(path)
+                     break
+
+             if checkpoint is None:
+                 logger.warning(
+                     "SAM 2 checkpoint not found. Using mock predictor. "
+                     "Download from: https://github.com/facebookresearch/segment-anything-2"
+                 )
+                 return MockSAM2Predictor(device)
+
+             # Select config file based on checkpoint
+             if "base_plus" in checkpoint:
+                 config_file = "sam2_hiera_b+.yaml"
+             elif "base" in checkpoint:
+                 config_file = "sam2_hiera_b.yaml"
+             else:
+                 config_file = "sam2_hiera_l.yaml"
+
+             predictor = build_sam2_video_predictor(
+                 config_file=config_file,
+                 checkpoint=checkpoint,
+                 device=device,
+             )
+
+             return SAM2VideoPredictor(predictor, device)
+
+         except Exception as e:
+             logger.warning(f"Failed to load SAM 2: {e}. Using mock predictor.")
+             return MockSAM2Predictor(device)
+
+     else:
+         logger.warning(
+             "SAM 2 not installed. Using mock predictor. "
+             "Install with: pip install segment-anything-2"
+         )
+         return MockSAM2Predictor(device)
+
+
+ def check_sam2_available() -> bool:
+     """Check if SAM 2 is available."""
+     return SAM2_AVAILABLE
+
+
+ __all__ = [
+     "SAM2VideoPredictor",
+     "MockSAM2Predictor",
+     "load_sam2_predictor",
+     "check_sam2_available",
+ ]
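
The video-segmentation flow implied by this wrapper is initialize, prompt, propagate, reset. The sketch below is illustrative and not part of the diff; the video path, object IDs, and click coordinates are placeholders, and without SAM 2 installed the loader returns the mock predictor.

    from ate.urdf.models.sam2 import load_sam2_predictor

    predictor = load_sam2_predictor(device="cpu")    # MockSAM2Predictor if SAM 2 is absent
    predictor.initialize("scan.mp4")                 # extracts frames to a temp directory

    # One positive click per articulated link on the first frame.
    predictor.add_point_prompt(frame_idx=0, obj_id=1, point=(320.0, 240.0))
    predictor.add_point_prompt(frame_idx=0, obj_id=2, point=(480.0, 260.0))

    masks = predictor.propagate()                    # {frame_idx: {obj_id: bool mask}}
    first_frame_masks = masks.get(0, {})
    predictor.reset()
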