foodforthought-cli 0.2.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. ate/__init__.py +6 -0
  2. ate/__main__.py +16 -0
  3. ate/auth/__init__.py +1 -0
  4. ate/auth/device_flow.py +141 -0
  5. ate/auth/token_store.py +96 -0
  6. ate/behaviors/__init__.py +12 -0
  7. ate/behaviors/approach.py +399 -0
  8. ate/cli.py +855 -4551
  9. ate/client.py +90 -0
  10. ate/commands/__init__.py +168 -0
  11. ate/commands/auth.py +389 -0
  12. ate/commands/bridge.py +448 -0
  13. ate/commands/data.py +185 -0
  14. ate/commands/deps.py +111 -0
  15. ate/commands/generate.py +384 -0
  16. ate/commands/memory.py +907 -0
  17. ate/commands/parts.py +166 -0
  18. ate/commands/primitive.py +399 -0
  19. ate/commands/protocol.py +288 -0
  20. ate/commands/recording.py +524 -0
  21. ate/commands/repo.py +154 -0
  22. ate/commands/simulation.py +291 -0
  23. ate/commands/skill.py +303 -0
  24. ate/commands/skills.py +487 -0
  25. ate/commands/team.py +147 -0
  26. ate/commands/workflow.py +271 -0
  27. ate/detection/__init__.py +38 -0
  28. ate/detection/base.py +142 -0
  29. ate/detection/color_detector.py +402 -0
  30. ate/detection/trash_detector.py +322 -0
  31. ate/drivers/__init__.py +18 -6
  32. ate/drivers/ble_transport.py +405 -0
  33. ate/drivers/mechdog.py +360 -24
  34. ate/drivers/wifi_camera.py +477 -0
  35. ate/interfaces/__init__.py +16 -0
  36. ate/interfaces/base.py +2 -0
  37. ate/interfaces/sensors.py +247 -0
  38. ate/llm_proxy.py +239 -0
  39. ate/memory/__init__.py +35 -0
  40. ate/memory/cloud.py +244 -0
  41. ate/memory/context.py +269 -0
  42. ate/memory/embeddings.py +184 -0
  43. ate/memory/export.py +26 -0
  44. ate/memory/merge.py +146 -0
  45. ate/memory/migrate/__init__.py +34 -0
  46. ate/memory/migrate/base.py +89 -0
  47. ate/memory/migrate/pipeline.py +189 -0
  48. ate/memory/migrate/sources/__init__.py +13 -0
  49. ate/memory/migrate/sources/chroma.py +170 -0
  50. ate/memory/migrate/sources/pinecone.py +120 -0
  51. ate/memory/migrate/sources/qdrant.py +110 -0
  52. ate/memory/migrate/sources/weaviate.py +160 -0
  53. ate/memory/reranker.py +353 -0
  54. ate/memory/search.py +26 -0
  55. ate/memory/store.py +548 -0
  56. ate/recording/__init__.py +42 -3
  57. ate/recording/session.py +12 -2
  58. ate/recording/visual.py +416 -0
  59. ate/robot/__init__.py +142 -0
  60. ate/robot/agentic_servo.py +856 -0
  61. ate/robot/behaviors.py +493 -0
  62. ate/robot/ble_capture.py +1000 -0
  63. ate/robot/ble_enumerate.py +506 -0
  64. ate/robot/calibration.py +88 -3
  65. ate/robot/calibration_state.py +388 -0
  66. ate/robot/commands.py +143 -11
  67. ate/robot/direction_calibration.py +554 -0
  68. ate/robot/discovery.py +104 -2
  69. ate/robot/llm_system_id.py +654 -0
  70. ate/robot/locomotion_calibration.py +508 -0
  71. ate/robot/marker_generator.py +611 -0
  72. ate/robot/perception.py +502 -0
  73. ate/robot/primitives.py +614 -0
  74. ate/robot/profiles.py +6 -0
  75. ate/robot/registry.py +5 -2
  76. ate/robot/servo_mapper.py +1153 -0
  77. ate/robot/skill_upload.py +285 -3
  78. ate/robot/target_calibration.py +500 -0
  79. ate/robot/teach.py +515 -0
  80. ate/robot/types.py +242 -0
  81. ate/robot/visual_labeler.py +9 -0
  82. ate/robot/visual_servo_loop.py +494 -0
  83. ate/robot/visual_servoing.py +570 -0
  84. ate/robot/visual_system_id.py +906 -0
  85. ate/transports/__init__.py +121 -0
  86. ate/transports/base.py +394 -0
  87. ate/transports/ble.py +405 -0
  88. ate/transports/hybrid.py +444 -0
  89. ate/transports/serial.py +345 -0
  90. ate/urdf/__init__.py +30 -0
  91. ate/urdf/capture.py +582 -0
  92. ate/urdf/cloud.py +491 -0
  93. ate/urdf/collision.py +271 -0
  94. ate/urdf/commands.py +708 -0
  95. ate/urdf/depth.py +360 -0
  96. ate/urdf/inertial.py +312 -0
  97. ate/urdf/kinematics.py +330 -0
  98. ate/urdf/lifting.py +415 -0
  99. ate/urdf/meshing.py +300 -0
  100. ate/urdf/models/__init__.py +110 -0
  101. ate/urdf/models/depth_anything.py +253 -0
  102. ate/urdf/models/sam2.py +324 -0
  103. ate/urdf/motion_analysis.py +396 -0
  104. ate/urdf/pipeline.py +468 -0
  105. ate/urdf/scale.py +256 -0
  106. ate/urdf/scan_session.py +411 -0
  107. ate/urdf/segmentation.py +299 -0
  108. ate/urdf/synthesis.py +319 -0
  109. ate/urdf/topology.py +336 -0
  110. ate/urdf/validation.py +371 -0
  111. {foodforthought_cli-0.2.8.dist-info → foodforthought_cli-0.3.1.dist-info}/METADATA +1 -1
  112. foodforthought_cli-0.3.1.dist-info/RECORD +166 -0
  113. {foodforthought_cli-0.2.8.dist-info → foodforthought_cli-0.3.1.dist-info}/WHEEL +1 -1
  114. foodforthought_cli-0.2.8.dist-info/RECORD +0 -73
  115. {foodforthought_cli-0.2.8.dist-info → foodforthought_cli-0.3.1.dist-info}/entry_points.txt +0 -0
  116. {foodforthought_cli-0.2.8.dist-info → foodforthought_cli-0.3.1.dist-info}/top_level.txt +0 -0
ate/memory/cloud.py ADDED
@@ -0,0 +1,244 @@
1
+ """
2
+ Cloud client for FoodforThought memory API.
3
+
4
+ Handles push (upload), pull (download), list, and delete of .mv2 memory files.
5
+ """
6
+
7
import os
from dataclasses import dataclass
from typing import List, Optional
from urllib.parse import quote

import requests
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Result dataclasses
16
+ # ---------------------------------------------------------------------------
17
+
18
@dataclass
class PushResult:
    """Details of an uploaded memory file, as returned by the upload endpoint."""

    id: str  # identifier returned by the server
    name: str  # file name the server recorded
    project: str  # project the file was uploaded into
    size_bytes: int  # size reported by the server
    url: str  # URL returned by the server for the uploaded file
    created_at: str  # creation timestamp, kept as the server's string
26
+
27
+
28
@dataclass
class PullResult:
    """Outcome of downloading a memory file to local disk."""

    path: str  # local path the file was written to
    size_bytes: int  # number of bytes written locally
    name: str  # memory file name that was requested
    project: str  # project the file was requested from
34
+
35
+
36
@dataclass
class MemoryListItem:
    """One entry from the server's listing of memory files in a project."""

    id: str  # identifier returned by the server
    name: str  # memory file name
    project: str  # project the file belongs to
    size_bytes: int  # size reported by the server
    created_at: str  # creation timestamp, kept as the server's string
    updated_at: str  # last-update timestamp, kept as the server's string
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Exceptions
48
+ # ---------------------------------------------------------------------------
49
+
50
class CloudError(Exception):
    """Root of the exception hierarchy for cloud memory operations."""
53
+
54
+
55
class CloudAuthError(CloudError):
    """Authentication was not configured, or the server rejected the token."""
58
+
59
+
60
class CloudNotFoundError(CloudError):
    """The server answered 404 for the requested memory file."""
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Client
67
+ # ---------------------------------------------------------------------------
68
+
69
class CloudClient:
    """Client for FoodforThought memory cloud API.

    Every public method requires a bearer token (see ``_auth_headers``) and
    passes ``timeout`` to ``requests`` so a stalled connection cannot hang the
    caller indefinitely. Network failures and timeouts surface as the usual
    ``requests.RequestException`` subclasses.
    """

    DEFAULT_TIMEOUT = 60.0
    # Size of the chunks streamed to disk by ``pull``.
    _DOWNLOAD_CHUNK_SIZE = 64 * 1024

    def __init__(
        self,
        server_url: str = "https://kindly.fyi",
        token: Optional[str] = None,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Configure the client.

        Args:
            server_url: Base URL of the API; trailing slashes are stripped.
            token: Bearer token sent with every request.
            timeout: Seconds passed as ``timeout=`` to each ``requests`` call.
                It bounds connecting and each wait for data, not the whole
                transfer, so large uploads/downloads are not cut off.
        """
        self.server_url = server_url.rstrip("/")
        self.token = token
        self.timeout = timeout

    # -- public API --------------------------------------------------------

    def push(self, local_path: str, project: str, name: Optional[str] = None) -> PushResult:
        """Upload .mv2 file to cloud.

        POST /api/memory/upload (multipart/form-data)

        Args:
            local_path: Path to the local .mv2 file.
            project: Project identifier (e.g. "kindly/memories").
            name: Optional name override; defaults to the file's basename.

        Returns:
            PushResult with upload details.

        Raises:
            CloudAuthError: If no token is set, or the server answers 401.
            CloudError: On any other HTTP error status, or a response body
                that is not JSON / lacks the expected fields.
            OSError: If ``local_path`` cannot be opened.
        """
        headers = self._auth_headers()
        resolved_name = name or os.path.basename(local_path)
        url = f"{self.server_url}/api/memory/upload"

        with open(local_path, "rb") as f:
            resp = requests.post(
                url,
                headers=headers,
                files={"file": (resolved_name, f)},
                data={"project": project, "name": resolved_name},
                timeout=self.timeout,
            )

        self._check_status(resp, "Push")
        body = self._json_body(resp, "Push")
        try:
            return PushResult(
                id=body["id"],
                name=body["name"],
                project=body["project"],
                size_bytes=body["size_bytes"],
                url=body["url"],
                created_at=body["created_at"],
            )
        except (KeyError, TypeError) as exc:
            raise CloudError(f"Push returned an unexpected response: {body!r}") from exc

    def pull(self, project: str, name: str, output_path: str) -> PullResult:
        """Download .mv2 file from cloud.

        GET /api/memory/{project}/{name}

        The body is streamed to ``<output_path>.part`` and renamed onto
        ``output_path`` only once the download has completed. An interrupted
        transfer therefore never leaves a truncated file at ``output_path``,
        nor replaces an existing copy with one.

        Args:
            project: Project identifier.
            name: Memory file name.
            output_path: Local path to write the downloaded file.

        Returns:
            PullResult with download details.

        Raises:
            CloudAuthError: If no token is set, or the server answers 401.
            CloudNotFoundError: If the file is not found (404).
            CloudError: On any other HTTP error status.
            OSError: If the output file cannot be written.
        """
        headers = self._auth_headers()
        url = self._memory_url(project, name)

        with requests.get(url, headers=headers, timeout=self.timeout, stream=True) as resp:
            self._check_status(resp, "Pull", not_found=f"Not found: {project}/{name}")

            parent = os.path.dirname(output_path)
            if parent:
                os.makedirs(parent, exist_ok=True)

            part_path = f"{output_path}.part"
            size = 0
            try:
                with open(part_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=self._DOWNLOAD_CHUNK_SIZE):
                        f.write(chunk)
                        size += len(chunk)
                os.replace(part_path, output_path)
            except BaseException:
                # Includes KeyboardInterrupt: don't leave a partial download behind.
                try:
                    os.remove(part_path)
                except OSError:
                    pass
                raise

        return PullResult(
            path=output_path,
            size_bytes=size,
            name=name,
            project=project,
        )

    def list(self, project: str) -> List[MemoryListItem]:
        """List memory files in a project.

        GET /api/memory/list?project={project}

        Args:
            project: Project identifier.

        Returns:
            List of MemoryListItem (empty when the response has no "items").

        Raises:
            CloudAuthError: If no token is set, or the server answers 401.
            CloudError: On any other HTTP error status, or a response body
                that is not JSON / lacks the expected fields.
        """
        headers = self._auth_headers()
        url = f"{self.server_url}/api/memory/list"

        resp = requests.get(url, headers=headers, params={"project": project}, timeout=self.timeout)

        self._check_status(resp, "List")
        body = self._json_body(resp, "List")
        try:
            return [
                MemoryListItem(
                    id=item["id"],
                    name=item["name"],
                    project=item["project"],
                    size_bytes=item["size_bytes"],
                    created_at=item["created_at"],
                    updated_at=item["updated_at"],
                )
                for item in body.get("items", [])
            ]
        except (AttributeError, KeyError, TypeError) as exc:
            raise CloudError(f"List returned an unexpected response: {body!r}") from exc

    def delete(self, project: str, name: str) -> bool:
        """Delete a memory file from cloud.

        DELETE /api/memory/{project}/{name}

        Args:
            project: Project identifier.
            name: Memory file name.

        Returns:
            True on success.

        Raises:
            CloudAuthError: If no token is set, or the server answers 401.
            CloudNotFoundError: If the file is not found (404).
            CloudError: On any other HTTP error status.
        """
        headers = self._auth_headers()
        url = self._memory_url(project, name)

        resp = requests.delete(url, headers=headers, timeout=self.timeout)

        self._check_status(resp, "Delete", not_found=f"Not found: {project}/{name}")
        return True

    # -- internal ----------------------------------------------------------

    def _auth_headers(self) -> dict:
        """Build authorization headers.

        Raises:
            CloudAuthError: If no token is configured.
        """
        if not self.token:
            raise CloudAuthError("Not authenticated. Run: ate device-login")
        return {"Authorization": f"Bearer {self.token}"}

    def _memory_url(self, project: str, name: str) -> str:
        """Build the per-file endpoint URL with percent-encoded path segments.

        ``project`` keeps "/" as a path separator (e.g. "kindly/memories").
        ``name`` is encoded completely, so characters such as "/", "?" or "#"
        cannot change which resource is addressed.
        """
        return f"{self.server_url}/api/memory/{quote(project, safe='/')}/{quote(name, safe='')}"

    @staticmethod
    def _check_status(resp, action: str, not_found: Optional[str] = None) -> None:
        """Translate HTTP error statuses into CloudError subclasses.

        Args:
            resp: The ``requests`` response to inspect.
            action: Operation name used in the generic message ("Push", ...).
            not_found: When given, a 404 raises CloudNotFoundError with this
                message. Otherwise a 404 falls through to the generic CloudError.
        """
        if not_found is not None and resp.status_code == 404:
            raise CloudNotFoundError(not_found)
        if resp.status_code == 401:
            raise CloudAuthError("Invalid or expired token")
        if resp.status_code >= 400:
            raise CloudError(f"{action} failed ({resp.status_code}): {resp.text}")

    @staticmethod
    def _json_body(resp, action: str):
        """Decode a JSON response body, raising CloudError if it is not JSON."""
        try:
            return resp.json()
        except ValueError as exc:  # requests' JSONDecodeError subclasses ValueError
            raise CloudError(
                f"{action} returned a non-JSON response ({resp.status_code})"
            ) from exc
ate/memory/context.py ADDED
@@ -0,0 +1,269 @@
1
+ """Context management for git-like memory operations.
2
+
3
+ This module provides the ContextManager that tracks the active memory and train
4
+ of thought, similar to how git tracks the current repository and branch.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import re
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import List, Optional, Dict, Any
13
+
14
+ from .store import MemoryStore
15
+
16
+
17
@dataclass
class MemoryContext:
    """Snapshot of which memory and train of thought are currently active."""

    active_memory: str  # memory directory name under the memories root
    active_train: str  # train of thought within that memory
    path: str  # .mv2 file that backs the active train
23
+
24
+
25
@dataclass
class MemoryMetadata:
    """Contents of a memory's memory.json file."""

    name: str  # memory name
    visibility: str  # visibility label ("private" by default in ContextManager)
    trains: List[str]  # known trains of thought (filename-safe form)
    default_train: str  # train used by default, e.g. "main"
    description: str = ""  # free-form description
    created_at: Optional[str] = None  # creation timestamp string, if recorded
    remote: Optional[str] = None  # remote reference, if recorded
35
+
36
+
37
class ContextManager:
    """Manages the active memory context (~/.ate/context.json).

    On-disk layout used by this class:

        CONTEXT_FILE                         active memory, train and .mv2 path
        MEMORIES_DIR/<memory>/memory.json    per-memory metadata
        MEMORIES_DIR/<memory>/<train>.mv2    one file per train of thought
                                             ("/" in train names becomes "-")
    """

    CONFIG_DIR = os.path.expanduser("~/.ate")
    CONTEXT_FILE = os.path.expanduser("~/.ate/context.json")
    MEMORIES_DIR = os.path.expanduser("~/.ate/memories")

    @classmethod
    def get_context(cls) -> MemoryContext:
        """Get current context. Auto-initializes on first use.

        Missing fields fall back to memory "default" and train "main". A missing
        ``path`` is derived from them. If the context file is absent, is not
        valid JSON, or does not hold a JSON object, the default memory is
        (re)initialized and made active.
        """
        data = cls._read_json(cls.CONTEXT_FILE)
        if data is None:
            return cls._auto_initialize()

        active_memory = data.get("active_memory", "default")
        active_train = data.get("active_train", "main")
        path = data.get("path") or cls._train_to_path(active_memory, active_train)

        return MemoryContext(
            active_memory=active_memory,
            active_train=active_train,
            path=path,
        )

    @classmethod
    def set_context(cls, memory: str, train: str) -> MemoryContext:
        """Set active context.

        Creates the memory directory and the train's .mv2 file when missing.
        If the memory's memory.json exists, the train is recorded in its
        "trains" list. The new context is then written to CONTEXT_FILE.

        Raises:
            ValueError: If ``memory`` is not a valid memory name.
        """
        if not cls._is_valid_memory_name(memory):
            raise ValueError(f"Invalid memory name '{memory}'. Use lowercase alphanumeric and hyphens only.")

        os.makedirs(cls.CONFIG_DIR, exist_ok=True)
        memory_dir = os.path.join(cls.MEMORIES_DIR, memory)
        os.makedirs(memory_dir, exist_ok=True)  # also creates MEMORIES_DIR

        path = cls._train_to_path(memory, train)
        cls._ensure_train_file(path)  # auto-create new trains
        cls._register_train(memory_dir, train)

        context = MemoryContext(active_memory=memory, active_train=train, path=path)
        cls._write_json(
            cls.CONTEXT_FILE,
            {"active_memory": memory, "active_train": train, "path": path},
        )
        return context

    @classmethod
    def resolve_path(cls, memory: Optional[str] = None, train: Optional[str] = None) -> str:
        """Resolve the .mv2 path from explicit args, falling back to the context.

        Whichever of ``memory``/``train`` is omitted is taken from the active
        context. With neither given, the context's stored path is returned.
        """
        if memory is not None and train is not None:
            return cls._train_to_path(memory, train)

        context = cls.get_context()
        if memory is not None:
            return cls._train_to_path(memory, context.active_train)
        if train is not None:
            return cls._train_to_path(context.active_memory, train)

        return context.path

    @classmethod
    def ensure_memory(cls, name: str) -> str:
        """Create the memory directory, its "main" train and memory.json if missing.

        Returns:
            Path to the memory's main.mv2.

        Raises:
            ValueError: If ``name`` is not a valid memory name.
        """
        if not cls._is_valid_memory_name(name):
            raise ValueError(f"Invalid memory name '{name}'. Use lowercase alphanumeric and hyphens only.")

        os.makedirs(cls.CONFIG_DIR, exist_ok=True)
        memory_dir = os.path.join(cls.MEMORIES_DIR, name)
        os.makedirs(memory_dir, exist_ok=True)  # also creates MEMORIES_DIR

        main_path = cls._train_to_path(name, "main")
        cls._ensure_train_file(main_path)

        memory_json_path = os.path.join(memory_dir, "memory.json")
        if not os.path.exists(memory_json_path):
            cls._write_json(memory_json_path, {
                "name": name,
                "visibility": "private",
                "trains": ["main"],
                "default_train": "main",
                "description": "",
            })

        return main_path

    @classmethod
    def list_memories(cls) -> List[MemoryMetadata]:
        """List all local memories, ordered by directory name.

        Only directories holding a memory.json that is a JSON object with a
        "name" field are included. Malformed entries are skipped.
        """
        if not os.path.isdir(cls.MEMORIES_DIR):
            return []

        memories = []
        for entry in sorted(os.listdir(cls.MEMORIES_DIR)):
            memory_dir = os.path.join(cls.MEMORIES_DIR, entry)
            if not os.path.isdir(memory_dir):
                continue
            data = cls._read_json(os.path.join(memory_dir, "memory.json"))
            if data is None or "name" not in data:
                continue  # Skip missing or malformed memory.json files
            memories.append(MemoryMetadata(
                name=data["name"],
                visibility=data.get("visibility", "private"),
                trains=data.get("trains", ["main"]),
                default_train=data.get("default_train", "main"),
                description=data.get("description", ""),
                created_at=data.get("created_at"),
                remote=data.get("remote"),
            ))

        return memories

    @classmethod
    def list_trains(cls, memory: Optional[str] = None) -> List[str]:
        """List trains of thought in a memory (the active one by default).

        Trains are taken from the .mv2 files on disk, plus the memory.json
        "trains" list in case files have not been created yet.

        Raises:
            FileNotFoundError: If the memory directory does not exist.
        """
        if memory is None:
            memory = cls.get_context().active_memory

        memory_dir = os.path.join(cls.MEMORIES_DIR, memory)
        if not os.path.exists(memory_dir):
            raise FileNotFoundError(f"Memory '{memory}' does not exist")

        # Source 1: .mv2 files on disk (extension stripped)
        trains = {entry[:-4] for entry in os.listdir(memory_dir) if entry.endswith(".mv2")}

        # Source 2: memory.json trains list
        data = cls._read_json(os.path.join(memory_dir, "memory.json"))
        if data is not None:
            listed = data.get("trains", [])
            if isinstance(listed, list):
                trains.update(listed)

        return sorted(trains)

    @classmethod
    def _auto_initialize(cls) -> MemoryContext:
        """Auto-initialize default memory on first use and make default/main active."""
        cls.ensure_memory("default")
        return cls.set_context("default", "main")

    @classmethod
    def _train_to_path(cls, memory: str, train: str) -> str:
        """Convert memory + train to .mv2 file path."""
        return os.path.join(cls.MEMORIES_DIR, memory, f"{cls._safe_train_name(train)}.mv2")

    @classmethod
    def _is_valid_memory_name(cls, name: str) -> bool:
        """Return True if ``name`` is only lowercase a-z, digits and hyphens.

        ``re.fullmatch`` is used because ``$`` in ``re.match`` also matches just
        before a trailing newline. That would let a name ending in a newline
        character through.
        """
        return isinstance(name, str) and re.fullmatch(r"[a-z0-9-]+", name) is not None

    # -- helpers -----------------------------------------------------------

    @staticmethod
    def _safe_train_name(train: str) -> str:
        """Convert a train name to filename-safe form (slashes to hyphens)."""
        return train.replace("/", "-")

    @staticmethod
    def _ensure_train_file(path: str) -> None:
        """Create the .mv2 file at ``path`` if it does not already exist.

        MemoryStore.create is tried first. If it fails or leaves nothing on
        disk (memvid create may not touch the filesystem), an empty file is
        created so list_trains can see the train.
        """
        if os.path.exists(path):
            return
        try:
            store = MemoryStore.create(path)
            store.close()
        except Exception:
            # Deliberately best-effort; the empty-file fallback below still runs.
            # NOTE(review): confirm MemoryStore can later open an empty .mv2.
            pass
        if not os.path.exists(path):
            Path(path).touch()

    @classmethod
    def _register_train(cls, memory_dir: str, train: str) -> None:
        """Add ``train`` (filename-safe form) to memory.json's "trains" list.

        No-op when memory.json is missing, unreadable, or already lists the
        train, or when its "trains" value is not a list.
        """
        memory_json = os.path.join(memory_dir, "memory.json")
        data = cls._read_json(memory_json)
        if data is None:
            return

        trains = data.get("trains")
        if trains is None:
            trains = data["trains"] = []
        elif not isinstance(trains, list):
            return  # malformed metadata; leave it untouched

        safe_train = cls._safe_train_name(train)
        if safe_train in trains:
            return
        trains.append(safe_train)
        cls._write_json(memory_json, data)

    @staticmethod
    def _read_json(path: str) -> Optional[Dict[str, Any]]:
        """Load a JSON object from ``path``.

        Returns None when the file does not exist, cannot be decoded, or holds
        something other than a JSON object. Other OS errors propagate.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _write_json(path: str, data: Dict[str, Any]) -> None:
        """Write ``data`` as indented JSON, replacing ``path`` atomically.

        The JSON goes to a sibling temp file that is then renamed over
        ``path``. An interrupted write therefore cannot leave a truncated
        context.json, which get_context would treat as corrupt and silently
        reset to the default memory.
        """
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
ate/memory/embeddings.py ADDED
@@ -0,0 +1,184 @@
1
+ """Embedding configuration and management for ate memory."""
2
+
3
+ import os
4
+ import requests
5
+ from dataclasses import dataclass
6
+ from typing import Optional, List, Dict, Any
7
+
8
+ import memvid_sdk
9
+
10
+
11
@dataclass
class EmbeddingConfig:
    """Embedding provider configuration.

    When ``model`` is left as None it is filled with the provider's default
    model. For provider "none" (BM25-only search) it stays None.
    """
    provider: str = "none"
    model: Optional[str] = None
    api_key: Optional[str] = None

    def __post_init__(self):
        """Reject unknown providers, then apply the provider's default model."""
        defaults_by_provider = {
            "openai": "text-embedding-3-small",
            "cohere": "embed-english-v3.0",
            "voyage": "voyage-2",
            "ollama": "nomic-embed-text",
            "none": None,
        }
        valid_providers = list(defaults_by_provider)
        if self.provider not in valid_providers:
            raise ValueError(f"Invalid provider: {self.provider}. Must be one of {valid_providers}")

        if self.model is None:
            self.model = defaults_by_provider[self.provider]
34
+
35
+
36
class EmbeddingManager:
    """Detects and manages embedding providers for ate memory."""

    # Hosted providers enabled by an API-key environment variable, in
    # detection priority order. Ollama (local, keyless) is probed after these.
    _KEYED_PROVIDERS = (
        ("openai", "OPENAI_API_KEY"),
        ("cohere", "COHERE_API_KEY"),
        ("voyage", "VOYAGE_API_KEY"),
    )
    _DEFAULT_OLLAMA_HOST = "http://localhost:11434"

    @staticmethod
    def detect() -> EmbeddingConfig:
        """Auto-detect best available embedding provider from env.

        Detection order: OpenAI → Cohere → Voyage → Ollama → BM25-only

        Returns:
            EmbeddingConfig with detected provider and settings
        """
        for provider, env_var in EmbeddingManager._KEYED_PROVIDERS:
            api_key = os.environ.get(env_var)
            if api_key:
                return EmbeddingConfig(provider=provider, api_key=api_key)

        if EmbeddingManager._is_ollama_available(EmbeddingManager._ollama_host()):
            return EmbeddingConfig(provider="ollama", model="nomic-embed-text")

        # No providers available
        return EmbeddingConfig(provider="none")

    @staticmethod
    def _ollama_host() -> str:
        """Return OLLAMA_HOST, or the default local endpoint when unset or empty."""
        return os.environ.get("OLLAMA_HOST") or EmbeddingManager._DEFAULT_OLLAMA_HOST

    @staticmethod
    def _is_ollama_available(host: str) -> bool:
        """Check if Ollama answers GET {host}/api/tags with HTTP 200.

        ``host`` may omit the scheme (e.g. "localhost:11434"), in which case
        "http://" is assumed. Any request failure counts as unavailable. Only
        request/URL errors are caught, so Ctrl-C still interrupts the probe.
        """
        base = host.strip().rstrip("/")
        if "://" not in base:
            base = f"http://{base}"
        try:
            response = requests.get(f"{base}/api/tags", timeout=2)
        except (requests.RequestException, ValueError):
            return False
        return response.status_code == 200

    @staticmethod
    def get_provider(config: EmbeddingConfig):
        """Get a memvid_sdk EmbeddingProvider from config.

        Args:
            config: EmbeddingConfig instance

        Returns:
            EmbeddingProvider instance or None if provider is "none"
        """
        if config.provider == "none":
            return None

        return memvid_sdk.embeddings.get_embedder(
            provider=config.provider,
            model=config.model,
            api_key=config.api_key
        )

    @staticmethod
    def available_providers() -> List[Dict[str, Any]]:
        """List all detected providers with status.

        Returns:
            One dict per provider. Available ones carry "model" and "source".
            Unavailable ones carry a "reason".
        """
        providers: List[Dict[str, Any]] = []

        for provider, env_var in EmbeddingManager._KEYED_PROVIDERS:
            if os.environ.get(env_var):
                providers.append({
                    "name": provider,
                    "available": True,
                    "model": EmbeddingConfig(provider=provider).model,
                    "source": env_var,
                })
            else:
                providers.append({
                    "name": provider,
                    "available": False,
                    "reason": f"{env_var} not set",
                })

        ollama_host = EmbeddingManager._ollama_host()
        if EmbeddingManager._is_ollama_available(ollama_host):
            providers.append({
                "name": "ollama",
                "available": True,
                "model": EmbeddingConfig(provider="ollama").model,
                "source": ollama_host,
            })
        elif os.environ.get("OLLAMA_HOST"):
            # OLLAMA_HOST is set, so don't claim it isn't; report the probed host.
            providers.append({
                "name": "ollama",
                "available": False,
                "reason": f"OLLAMA_HOST={ollama_host} not reachable",
            })
        else:
            providers.append({
                "name": "ollama",
                "available": False,
                "reason": "OLLAMA_HOST not set, localhost:11434 not reachable",
            })

        return providers