foodforthought-cli 0.2.7-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ate/__init__.py +6 -0
- ate/__main__.py +16 -0
- ate/auth/__init__.py +1 -0
- ate/auth/device_flow.py +141 -0
- ate/auth/token_store.py +96 -0
- ate/behaviors/__init__.py +100 -0
- ate/behaviors/approach.py +399 -0
- ate/behaviors/common.py +686 -0
- ate/behaviors/tree.py +454 -0
- ate/cli.py +855 -3995
- ate/client.py +90 -0
- ate/commands/__init__.py +168 -0
- ate/commands/auth.py +389 -0
- ate/commands/bridge.py +448 -0
- ate/commands/data.py +185 -0
- ate/commands/deps.py +111 -0
- ate/commands/generate.py +384 -0
- ate/commands/memory.py +907 -0
- ate/commands/parts.py +166 -0
- ate/commands/primitive.py +399 -0
- ate/commands/protocol.py +288 -0
- ate/commands/recording.py +524 -0
- ate/commands/repo.py +154 -0
- ate/commands/simulation.py +291 -0
- ate/commands/skill.py +303 -0
- ate/commands/skills.py +487 -0
- ate/commands/team.py +147 -0
- ate/commands/workflow.py +271 -0
- ate/detection/__init__.py +38 -0
- ate/detection/base.py +142 -0
- ate/detection/color_detector.py +399 -0
- ate/detection/trash_detector.py +322 -0
- ate/drivers/__init__.py +39 -0
- ate/drivers/ble_transport.py +405 -0
- ate/drivers/mechdog.py +942 -0
- ate/drivers/wifi_camera.py +477 -0
- ate/interfaces/__init__.py +187 -0
- ate/interfaces/base.py +273 -0
- ate/interfaces/body.py +267 -0
- ate/interfaces/detection.py +282 -0
- ate/interfaces/locomotion.py +422 -0
- ate/interfaces/manipulation.py +408 -0
- ate/interfaces/navigation.py +389 -0
- ate/interfaces/perception.py +362 -0
- ate/interfaces/sensors.py +247 -0
- ate/interfaces/types.py +371 -0
- ate/llm_proxy.py +239 -0
- ate/mcp_server.py +387 -0
- ate/memory/__init__.py +35 -0
- ate/memory/cloud.py +244 -0
- ate/memory/context.py +269 -0
- ate/memory/embeddings.py +184 -0
- ate/memory/export.py +26 -0
- ate/memory/merge.py +146 -0
- ate/memory/migrate/__init__.py +34 -0
- ate/memory/migrate/base.py +89 -0
- ate/memory/migrate/pipeline.py +189 -0
- ate/memory/migrate/sources/__init__.py +13 -0
- ate/memory/migrate/sources/chroma.py +170 -0
- ate/memory/migrate/sources/pinecone.py +120 -0
- ate/memory/migrate/sources/qdrant.py +110 -0
- ate/memory/migrate/sources/weaviate.py +160 -0
- ate/memory/reranker.py +353 -0
- ate/memory/search.py +26 -0
- ate/memory/store.py +548 -0
- ate/recording/__init__.py +83 -0
- ate/recording/demonstration.py +378 -0
- ate/recording/session.py +415 -0
- ate/recording/upload.py +304 -0
- ate/recording/visual.py +416 -0
- ate/recording/wrapper.py +95 -0
- ate/robot/__init__.py +221 -0
- ate/robot/agentic_servo.py +856 -0
- ate/robot/behaviors.py +493 -0
- ate/robot/ble_capture.py +1000 -0
- ate/robot/ble_enumerate.py +506 -0
- ate/robot/calibration.py +668 -0
- ate/robot/calibration_state.py +388 -0
- ate/robot/commands.py +3735 -0
- ate/robot/direction_calibration.py +554 -0
- ate/robot/discovery.py +441 -0
- ate/robot/introspection.py +330 -0
- ate/robot/llm_system_id.py +654 -0
- ate/robot/locomotion_calibration.py +508 -0
- ate/robot/manager.py +270 -0
- ate/robot/marker_generator.py +611 -0
- ate/robot/perception.py +502 -0
- ate/robot/primitives.py +614 -0
- ate/robot/profiles.py +281 -0
- ate/robot/registry.py +322 -0
- ate/robot/servo_mapper.py +1153 -0
- ate/robot/skill_upload.py +675 -0
- ate/robot/target_calibration.py +500 -0
- ate/robot/teach.py +515 -0
- ate/robot/types.py +242 -0
- ate/robot/visual_labeler.py +1048 -0
- ate/robot/visual_servo_loop.py +494 -0
- ate/robot/visual_servoing.py +570 -0
- ate/robot/visual_system_id.py +906 -0
- ate/transports/__init__.py +121 -0
- ate/transports/base.py +394 -0
- ate/transports/ble.py +405 -0
- ate/transports/hybrid.py +444 -0
- ate/transports/serial.py +345 -0
- ate/urdf/__init__.py +30 -0
- ate/urdf/capture.py +582 -0
- ate/urdf/cloud.py +491 -0
- ate/urdf/collision.py +271 -0
- ate/urdf/commands.py +708 -0
- ate/urdf/depth.py +360 -0
- ate/urdf/inertial.py +312 -0
- ate/urdf/kinematics.py +330 -0
- ate/urdf/lifting.py +415 -0
- ate/urdf/meshing.py +300 -0
- ate/urdf/models/__init__.py +110 -0
- ate/urdf/models/depth_anything.py +253 -0
- ate/urdf/models/sam2.py +324 -0
- ate/urdf/motion_analysis.py +396 -0
- ate/urdf/pipeline.py +468 -0
- ate/urdf/scale.py +256 -0
- ate/urdf/scan_session.py +411 -0
- ate/urdf/segmentation.py +299 -0
- ate/urdf/synthesis.py +319 -0
- ate/urdf/topology.py +336 -0
- ate/urdf/validation.py +371 -0
- {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/METADATA +9 -1
- foodforthought_cli-0.3.0.dist-info/RECORD +166 -0
- {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/WHEEL +1 -1
- foodforthought_cli-0.2.7.dist-info/RECORD +0 -44
- {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/entry_points.txt +0 -0
- {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/top_level.txt +0 -0
ate/urdf/models/__init__.py
ADDED
@@ -0,0 +1,110 @@
+"""
+Model loading and caching for URDF scan pipeline.
+
+Provides lazy loading and caching for large ML models:
+- SAM 2 (Segment Anything Model 2) for temporal segmentation
+- Depth Anything V2 for metric depth estimation
+"""
+
+from typing import Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class ModelCache:
+    """
+    Singleton cache for expensive ML models.
+
+    Models are loaded lazily on first access and cached for reuse.
+    This prevents redundant loading when processing multiple sessions.
+    """
+
+    _sam2_model = None
+    _sam2_predictor = None
+    _depth_model = None
+    _device: str = "cpu"
+
+    @classmethod
+    def set_device(cls, device: str) -> None:
+        """Set the device for model inference (cuda or cpu)."""
+        cls._device = device
+        # Clear cached models if device changes
+        if cls._sam2_model is not None or cls._depth_model is not None:
+            logger.info(f"Device changed to {device}, clearing model cache")
+            cls.clear()
+
+    @classmethod
+    def get_device(cls) -> str:
+        """Get the current device setting."""
+        return cls._device
+
+    @classmethod
+    def get_sam2(cls):
+        """
+        Get SAM 2 model, loading if necessary.
+
+        Returns:
+            SAM 2 predictor ready for video inference
+        """
+        if cls._sam2_predictor is None:
+            from .sam2 import load_sam2_predictor
+            logger.info(f"Loading SAM 2 model on {cls._device}...")
+            cls._sam2_predictor = load_sam2_predictor(cls._device)
+            logger.info("SAM 2 model loaded")
+        return cls._sam2_predictor
+
+    @classmethod
+    def get_depth_model(cls):
+        """
+        Get Depth Anything V2 model, loading if necessary.
+
+        Returns:
+            Depth Anything V2 model ready for inference
+        """
+        if cls._depth_model is None:
+            from .depth_anything import load_depth_model
+            logger.info(f"Loading Depth Anything V2 on {cls._device}...")
+            cls._depth_model = load_depth_model(cls._device)
+            logger.info("Depth Anything V2 loaded")
+        return cls._depth_model
+
+    @classmethod
+    def clear(cls) -> None:
+        """Clear all cached models to free memory."""
+        cls._sam2_model = None
+        cls._sam2_predictor = None
+        cls._depth_model = None
+        logger.info("Model cache cleared")
+
+    @classmethod
+    def is_loaded(cls, model_name: str) -> bool:
+        """Check if a specific model is loaded."""
+        if model_name == "sam2":
+            return cls._sam2_predictor is not None
+        elif model_name == "depth":
+            return cls._depth_model is not None
+        return False
+
+
+def check_cuda_available() -> bool:
+    """Check if CUDA is available for GPU acceleration."""
+    try:
+        import torch
+        return torch.cuda.is_available()
+    except ImportError:
+        return False
+
+
+def get_recommended_device() -> str:
+    """Get the recommended device based on available hardware."""
+    if check_cuda_available():
+        return "cuda"
+    return "cpu"
+
+
+__all__ = [
+    "ModelCache",
+    "check_cuda_available",
+    "get_recommended_device",
+]
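
A minimal usage sketch (not part of the package contents), assuming the ate.urdf.models layout shown in the file listing above: pick a device, then fetch models lazily through the cache.

    from ate.urdf.models import ModelCache, get_recommended_device

    # "cuda" when available, otherwise "cpu"; changing the device clears the cache
    ModelCache.set_device(get_recommended_device())
    predictor = ModelCache.get_sam2()            # loaded on first access, then reused
    depth_model = ModelCache.get_depth_model()   # cached across sessions
    print(ModelCache.is_loaded("sam2"), ModelCache.is_loaded("depth"))
    ModelCache.clear()                           # free memory when done
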
ate/urdf/models/depth_anything.py
ADDED
@@ -0,0 +1,253 @@
+"""
+Depth Anything V2 wrapper for metric depth estimation.
+
+This module provides:
+- load_depth_model: Load the Depth Anything V2 model
+- DepthEstimator: Wrapper for depth inference
+
+When Depth Anything V2 is not installed, a mock implementation
+is used that generates synthetic depth maps for testing.
+
+Installation:
+    pip install depth-anything-v2
+
+References:
+    https://github.com/DepthAnything/Depth-Anything-V2
+"""
+
+import logging
+from typing import Optional, Tuple
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+try:
+    import numpy as np
+    NUMPY_AVAILABLE = True
+except ImportError:
+    NUMPY_AVAILABLE = False
+    np = None
+
+try:
+    import torch
+    import torch.nn.functional as F
+    from depth_anything_v2.dpt import DepthAnythingV2
+    DEPTH_ANYTHING_AVAILABLE = True
+except ImportError:
+    DEPTH_ANYTHING_AVAILABLE = False
+    torch = None
+
+
+class DepthEstimator:
+    """
+    Wrapper for Depth Anything V2 model.
+
+    Provides metric depth estimation from single RGB images.
+    """
+
+    def __init__(self, model, device: str = "cpu"):
+        """
+        Initialize estimator wrapper.
+
+        Args:
+            model: Depth Anything V2 model instance
+            device: Compute device
+        """
+        self.model = model
+        self.device = device
+
+    def estimate(self, image: "np.ndarray") -> "np.ndarray":
+        """
+        Estimate depth from an RGB image.
+
+        Args:
+            image: BGR image from OpenCV (H, W, 3)
+
+        Returns:
+            Depth map (H, W) in relative units
+        """
+        import cv2
+
+        # Convert BGR to RGB
+        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Prepare input
+        h, w = rgb.shape[:2]
+
+        # Resize to model input size (typically 518x518 for ViT)
+        input_size = 518
+        rgb_resized = cv2.resize(rgb, (input_size, input_size))
+
+        # Normalize
+        rgb_tensor = torch.from_numpy(rgb_resized).float().permute(2, 0, 1)
+        rgb_tensor = rgb_tensor / 255.0
+        rgb_tensor = rgb_tensor.unsqueeze(0).to(self.device)
+
+        # Inference
+        with torch.no_grad():
+            depth = self.model(rgb_tensor)
+
+        # Resize back to original
+        depth = F.interpolate(
+            depth.unsqueeze(1),
+            size=(h, w),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        depth_np = depth.squeeze().cpu().numpy()
+        return depth_np
+
+    def estimate_metric(
+        self,
+        image: "np.ndarray",
+        scale_factor: float = 1.0,
+    ) -> "np.ndarray":
+        """
+        Estimate metric depth with scale correction.
+
+        Args:
+            image: BGR image from OpenCV
+            scale_factor: Scale factor to convert to meters
+
+        Returns:
+            Depth map (H, W) in meters
+        """
+        depth = self.estimate(image)
+        return depth * scale_factor
+
+
+class MockDepthEstimator:
+    """
+    Mock depth estimator for testing without the real model.
+
+    Generates synthetic depth maps based on image brightness,
+    simulating closer objects being brighter (like typical indoor lighting).
+    """
+
+    def __init__(self, device: str = "cpu"):
+        """Initialize mock estimator."""
+        self.device = device
+
+    def estimate(self, image: "np.ndarray") -> "np.ndarray":
+        """
+        Generate mock depth from image brightness.
+
+        Args:
+            image: BGR image
+
+        Returns:
+            Mock depth map
+        """
+        import cv2
+
+        # Convert to grayscale
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32)
+
+        # Invert: brighter = closer = smaller depth
+        depth = 255.0 - gray
+
+        # Normalize to [0.5, 5.0] meters range (typical indoor scene)
+        depth = 0.5 + (depth / 255.0) * 4.5
+
+        # Add some noise for realism
+        noise = np.random.normal(0, 0.1, depth.shape).astype(np.float32)
+        depth = depth + noise
+        depth = np.clip(depth, 0.3, 6.0)
+
+        return depth
+
+    def estimate_metric(
+        self,
+        image: "np.ndarray",
+        scale_factor: float = 1.0,
+    ) -> "np.ndarray":
+        """Generate mock metric depth."""
+        depth = self.estimate(image)
+        # Scale factor adjusts the overall scale
+        return depth * scale_factor
+
+
+def load_depth_model(device: str = "cpu") -> DepthEstimator:
+    """
+    Load Depth Anything V2 model.
+
+    Args:
+        device: Compute device ("cuda" or "cpu")
+
+    Returns:
+        DepthEstimator instance (or mock if unavailable)
+    """
+    if DEPTH_ANYTHING_AVAILABLE:
+        try:
+            logger.info("Loading Depth Anything V2 model...")
+
+            # Model configurations
+            model_configs = {
+                "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
+                "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
+                "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
+            }
+
+            # Try to find checkpoint
+            checkpoint_paths = [
+                Path.home() / ".cache" / "depth_anything_v2" / "depth_anything_v2_vitl.pth",
+                Path.home() / ".cache" / "depth_anything_v2" / "depth_anything_v2_vitb.pth",
+                Path.home() / ".cache" / "depth_anything_v2" / "depth_anything_v2_vits.pth",
+                Path("./checkpoints/depth_anything_v2_vitb.pth"),
+            ]
+
+            checkpoint = None
+            config_name = None
+            for path in checkpoint_paths:
+                if path.exists():
+                    checkpoint = str(path)
+                    # Infer config from filename
+                    if "vitl" in path.name:
+                        config_name = "vitl"
+                    elif "vitb" in path.name:
+                        config_name = "vitb"
+                    else:
+                        config_name = "vits"
+                    break
+
+            if checkpoint is None:
+                logger.warning(
+                    "Depth Anything V2 checkpoint not found. Using mock estimator. "
+                    "Download from: https://github.com/DepthAnything/Depth-Anything-V2"
+                )
+                return MockDepthEstimator(device)
+
+            # Build model
+            config = model_configs[config_name]
+            model = DepthAnythingV2(**config)
+            model.load_state_dict(torch.load(checkpoint, map_location=device))
+            model.to(device)
+            model.eval()
+
+            logger.info(f"Loaded Depth Anything V2 ({config_name}) on {device}")
+            return DepthEstimator(model, device)
+
+        except Exception as e:
+            logger.warning(f"Failed to load Depth Anything V2: {e}. Using mock estimator.")
+            return MockDepthEstimator(device)
+
+    else:
+        logger.warning(
+            "Depth Anything V2 not installed. Using mock estimator. "
+            "Install with: pip install depth-anything-v2"
+        )
+        return MockDepthEstimator(device)
+
+
+def check_depth_anything_available() -> bool:
+    """Check if Depth Anything V2 is available."""
+    return DEPTH_ANYTHING_AVAILABLE
+
+
+__all__ = [
+    "DepthEstimator",
+    "MockDepthEstimator",
+    "load_depth_model",
+    "check_depth_anything_available",
+]
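
A minimal usage sketch (not part of the package contents) for the estimator above; the frame path is a placeholder, and the mock estimator is returned automatically when the model or its checkpoint is missing.

    import cv2
    from ate.urdf.models.depth_anything import load_depth_model

    estimator = load_depth_model(device="cpu")
    frame = cv2.imread("frame_000000.jpg")      # BGR image, as estimate() expects
    depth = estimator.estimate(frame)           # relative depth map, shape (H, W)
    depth_m = estimator.estimate_metric(frame, scale_factor=1.0)  # scaled to meters
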
ate/urdf/models/sam2.py
ADDED
@@ -0,0 +1,324 @@
+"""
+SAM 2 (Segment Anything Model 2) wrapper for temporal segmentation.
+
+This module provides:
+- load_sam2_predictor: Load the SAM 2 video predictor
+- SAM2VideoPredictor: Wrapper for video segmentation
+
+When SAM 2 is not installed, a mock implementation is used that
+generates simple masks based on color clustering for testing.
+
+Installation:
+    pip install segment-anything-2
+
+References:
+    https://github.com/facebookresearch/segment-anything-2
+"""
+
+import logging
+from typing import Dict, List, Optional, Tuple
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+try:
+    import numpy as np
+    NUMPY_AVAILABLE = True
+except ImportError:
+    NUMPY_AVAILABLE = False
+    np = None
+
+try:
+    import torch
+    from sam2.build_sam import build_sam2_video_predictor
+    SAM2_AVAILABLE = True
+except ImportError:
+    SAM2_AVAILABLE = False
+    torch = None
+
+
+class SAM2VideoPredictor:
+    """
+    Wrapper for SAM 2 video segmentation.
+
+    Provides temporal mask propagation from initial point prompts.
+    """
+
+    def __init__(self, predictor, device: str = "cpu"):
+        """
+        Initialize predictor wrapper.
+
+        Args:
+            predictor: SAM 2 video predictor instance
+            device: Compute device
+        """
+        self.predictor = predictor
+        self.device = device
+        self._state = None
+
+    def initialize(self, video_path: str) -> None:
+        """
+        Initialize predictor with a video.
+
+        Args:
+            video_path: Path to video file
+        """
+        import cv2
+        import tempfile
+        import shutil
+        from pathlib import Path
+
+        # Extract frames using OpenCV (works on Apple Silicon)
+        # SAM 2's init_state can use an image directory instead of video
+        self._temp_dir = tempfile.mkdtemp(prefix="sam2_frames_")
+
+        cap = cv2.VideoCapture(video_path)
+        frame_idx = 0
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            # Save as JPEG with zero-padded frame numbers
+            frame_path = Path(self._temp_dir) / f"{frame_idx:06d}.jpg"
+            cv2.imwrite(str(frame_path), frame)
+            frame_idx += 1
+        cap.release()
+
+        logger.info(f"Extracted {frame_idx} frames to {self._temp_dir}")
+
+        # Use image directory mode instead of video mode
+        self._state = self.predictor.init_state(video_path=self._temp_dir)
+        logger.info(f"SAM 2 initialized with {frame_idx} frames")
+
+    def add_point_prompt(
+        self,
+        frame_idx: int,
+        obj_id: int,
+        point: Tuple[float, float],
+        label: int = 1,
+    ) -> "np.ndarray":
+        """
+        Add a point prompt to initialize tracking.
+
+        Args:
+            frame_idx: Frame index (usually 0)
+            obj_id: Unique object ID for this link
+            point: (x, y) coordinates in frame
+            label: 1 for foreground, 0 for background
+
+        Returns:
+            Initial mask for the object
+        """
+        points = np.array([[point[0], point[1]]], dtype=np.float32)
+        labels = np.array([label], dtype=np.int32)
+
+        _, _, mask_logits = self.predictor.add_new_points_or_box(
+            inference_state=self._state,
+            frame_idx=frame_idx,
+            obj_id=obj_id,
+            points=points,
+            labels=labels,
+        )
+
+        # Convert logits to binary mask
+        mask = (mask_logits > 0).cpu().numpy().squeeze()
+        return mask
+
+    def propagate(self) -> Dict[int, Dict[int, "np.ndarray"]]:
+        """
+        Propagate masks across all video frames.
+
+        Returns:
+            Dict mapping frame_idx -> {obj_id: mask}
+        """
+        results = {}
+
+        for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(
+            self._state
+        ):
+            frame_masks = {}
+            for i, obj_id in enumerate(obj_ids):
+                mask = (mask_logits[i] > 0).cpu().numpy().squeeze()
+                frame_masks[obj_id] = mask
+            results[frame_idx] = frame_masks
+
+        logger.info(f"SAM 2 propagated masks across {len(results)} frames")
+        return results
+
+    def reset(self) -> None:
+        """Reset predictor state."""
+        if self._state is not None:
+            self.predictor.reset_state(self._state)
+            self._state = None
+
+
+class MockSAM2Predictor:
+    """
+    Mock SAM 2 predictor for testing without the real model.
+
+    Uses simple color-based region growing from seed points.
+    This is NOT accurate but allows pipeline testing.
+    """
+
+    def __init__(self, device: str = "cpu"):
+        """Initialize mock predictor."""
+        self.device = device
+        self._video_path: Optional[str] = None
+        self._prompts: Dict[int, Tuple[float, float]] = {}  # obj_id -> point
+        self._frame_count: int = 0
+        self._resolution: Tuple[int, int] = (720, 1280)
+
+    def initialize(self, video_path: str) -> None:
+        """Initialize with video."""
+        import cv2
+
+        self._video_path = video_path
+        cap = cv2.VideoCapture(video_path)
+        self._frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        self._resolution = (height, width)
+        cap.release()
+
+        logger.info(f"Mock SAM 2 initialized: {self._frame_count} frames")
+
+    def add_point_prompt(
+        self,
+        frame_idx: int,
+        obj_id: int,
+        point: Tuple[float, float],
+        label: int = 1,
+    ) -> "np.ndarray":
+        """Add point prompt and return initial mask."""
+        self._prompts[obj_id] = point
+
+        # Generate simple circular mask around point
+        mask = self._generate_circular_mask(point, radius=50)
+        return mask
+
+    def _generate_circular_mask(
+        self,
+        center: Tuple[float, float],
+        radius: int = 50,
+    ) -> "np.ndarray":
+        """Generate a circular mask around a point."""
+        h, w = self._resolution
+        mask = np.zeros((h, w), dtype=bool)
+
+        cx, cy = int(center[0]), int(center[1])
+        y, x = np.ogrid[:h, :w]
+        dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
+        mask[dist <= radius] = True
+
+        return mask
+
+    def propagate(self) -> Dict[int, Dict[int, "np.ndarray"]]:
+        """
+        Generate mock masks for all frames.
+
+        Creates slightly varying masks to simulate motion tracking.
+        """
+        results = {}
+
+        for frame_idx in range(self._frame_count):
+            frame_masks = {}
+
+            for obj_id, point in self._prompts.items():
+                # Add small random offset to simulate motion
+                offset_x = np.random.uniform(-5, 5)
+                offset_y = np.random.uniform(-5, 5)
+                moved_point = (point[0] + offset_x, point[1] + offset_y)
+
+                # Vary radius slightly
+                radius = 50 + np.random.randint(-10, 10)
+                mask = self._generate_circular_mask(moved_point, radius)
+                frame_masks[obj_id] = mask
+
+            results[frame_idx] = frame_masks
+
+        logger.info(f"Mock SAM 2 generated masks for {len(results)} frames")
+        return results
+
+    def reset(self) -> None:
+        """Reset state."""
+        self._prompts = {}
+
+
+def load_sam2_predictor(device: str = "cpu") -> SAM2VideoPredictor:
+    """
+    Load SAM 2 video predictor.
+
+    Args:
+        device: Compute device ("cuda" or "cpu")
+
+    Returns:
+        SAM2VideoPredictor instance (or mock if unavailable)
+    """
+    if SAM2_AVAILABLE:
+        try:
+            # Try to load the real model
+            logger.info("Loading SAM 2 model...")
+
+            # SAM 2 model checkpoint paths
+            # Users should download from:
+            # https://github.com/facebookresearch/segment-anything-2#model-checkpoints
+            checkpoint_paths = [
+                Path.home() / ".cache" / "sam2" / "sam2_hiera_large.pt",
+                Path.home() / ".cache" / "sam2" / "sam2_hiera_base_plus.pt",
+                Path.home() / ".cache" / "sam2" / "sam2_hiera_base.pt",
+                Path("./checkpoints/sam2_hiera_large.pt"),
+                Path("./checkpoints/sam2_hiera_base_plus.pt"),
+            ]
+
+            checkpoint = None
+            for path in checkpoint_paths:
+                if path.exists():
+                    checkpoint = str(path)
+                    break
+
+            if checkpoint is None:
+                logger.warning(
+                    "SAM 2 checkpoint not found. Using mock predictor. "
+                    "Download from: https://github.com/facebookresearch/segment-anything-2"
+                )
+                return MockSAM2Predictor(device)
+
+            # Select config file based on checkpoint
+            if "base_plus" in checkpoint:
+                config_file = "sam2_hiera_b+.yaml"
+            elif "base" in checkpoint:
+                config_file = "sam2_hiera_b.yaml"
+            else:
+                config_file = "sam2_hiera_l.yaml"
+
+            predictor = build_sam2_video_predictor(
+                config_file=config_file,
+                checkpoint=checkpoint,
+                device=device,
+            )
+
+            return SAM2VideoPredictor(predictor, device)
+
+        except Exception as e:
+            logger.warning(f"Failed to load SAM 2: {e}. Using mock predictor.")
+            return MockSAM2Predictor(device)
+
+    else:
+        logger.warning(
+            "SAM 2 not installed. Using mock predictor. "
+            "Install with: pip install segment-anything-2"
+        )
+        return MockSAM2Predictor(device)
+
+
+def check_sam2_available() -> bool:
+    """Check if SAM 2 is available."""
+    return SAM2_AVAILABLE
+
+
+__all__ = [
+    "SAM2VideoPredictor",
+    "MockSAM2Predictor",
+    "load_sam2_predictor",
+    "check_sam2_available",
+]
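
A minimal usage sketch (not part of the package contents) for the predictor above; the video path and seed point are placeholders, and MockSAM2Predictor is substituted automatically when SAM 2 or a checkpoint is unavailable.

    from ate.urdf.models.sam2 import load_sam2_predictor

    predictor = load_sam2_predictor(device="cpu")
    predictor.initialize("scan.mp4")                  # extracts frames, builds tracking state
    predictor.add_point_prompt(frame_idx=0, obj_id=1, point=(320.0, 240.0))
    masks = predictor.propagate()                     # {frame_idx: {obj_id: boolean mask}}
    predictor.reset()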