foodforthought-cli 0.2.7-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. ate/__init__.py +6 -0
  2. ate/__main__.py +16 -0
  3. ate/auth/__init__.py +1 -0
  4. ate/auth/device_flow.py +141 -0
  5. ate/auth/token_store.py +96 -0
  6. ate/behaviors/__init__.py +100 -0
  7. ate/behaviors/approach.py +399 -0
  8. ate/behaviors/common.py +686 -0
  9. ate/behaviors/tree.py +454 -0
  10. ate/cli.py +855 -3995
  11. ate/client.py +90 -0
  12. ate/commands/__init__.py +168 -0
  13. ate/commands/auth.py +389 -0
  14. ate/commands/bridge.py +448 -0
  15. ate/commands/data.py +185 -0
  16. ate/commands/deps.py +111 -0
  17. ate/commands/generate.py +384 -0
  18. ate/commands/memory.py +907 -0
  19. ate/commands/parts.py +166 -0
  20. ate/commands/primitive.py +399 -0
  21. ate/commands/protocol.py +288 -0
  22. ate/commands/recording.py +524 -0
  23. ate/commands/repo.py +154 -0
  24. ate/commands/simulation.py +291 -0
  25. ate/commands/skill.py +303 -0
  26. ate/commands/skills.py +487 -0
  27. ate/commands/team.py +147 -0
  28. ate/commands/workflow.py +271 -0
  29. ate/detection/__init__.py +38 -0
  30. ate/detection/base.py +142 -0
  31. ate/detection/color_detector.py +399 -0
  32. ate/detection/trash_detector.py +322 -0
  33. ate/drivers/__init__.py +39 -0
  34. ate/drivers/ble_transport.py +405 -0
  35. ate/drivers/mechdog.py +942 -0
  36. ate/drivers/wifi_camera.py +477 -0
  37. ate/interfaces/__init__.py +187 -0
  38. ate/interfaces/base.py +273 -0
  39. ate/interfaces/body.py +267 -0
  40. ate/interfaces/detection.py +282 -0
  41. ate/interfaces/locomotion.py +422 -0
  42. ate/interfaces/manipulation.py +408 -0
  43. ate/interfaces/navigation.py +389 -0
  44. ate/interfaces/perception.py +362 -0
  45. ate/interfaces/sensors.py +247 -0
  46. ate/interfaces/types.py +371 -0
  47. ate/llm_proxy.py +239 -0
  48. ate/mcp_server.py +387 -0
  49. ate/memory/__init__.py +35 -0
  50. ate/memory/cloud.py +244 -0
  51. ate/memory/context.py +269 -0
  52. ate/memory/embeddings.py +184 -0
  53. ate/memory/export.py +26 -0
  54. ate/memory/merge.py +146 -0
  55. ate/memory/migrate/__init__.py +34 -0
  56. ate/memory/migrate/base.py +89 -0
  57. ate/memory/migrate/pipeline.py +189 -0
  58. ate/memory/migrate/sources/__init__.py +13 -0
  59. ate/memory/migrate/sources/chroma.py +170 -0
  60. ate/memory/migrate/sources/pinecone.py +120 -0
  61. ate/memory/migrate/sources/qdrant.py +110 -0
  62. ate/memory/migrate/sources/weaviate.py +160 -0
  63. ate/memory/reranker.py +353 -0
  64. ate/memory/search.py +26 -0
  65. ate/memory/store.py +548 -0
  66. ate/recording/__init__.py +83 -0
  67. ate/recording/demonstration.py +378 -0
  68. ate/recording/session.py +415 -0
  69. ate/recording/upload.py +304 -0
  70. ate/recording/visual.py +416 -0
  71. ate/recording/wrapper.py +95 -0
  72. ate/robot/__init__.py +221 -0
  73. ate/robot/agentic_servo.py +856 -0
  74. ate/robot/behaviors.py +493 -0
  75. ate/robot/ble_capture.py +1000 -0
  76. ate/robot/ble_enumerate.py +506 -0
  77. ate/robot/calibration.py +668 -0
  78. ate/robot/calibration_state.py +388 -0
  79. ate/robot/commands.py +3735 -0
  80. ate/robot/direction_calibration.py +554 -0
  81. ate/robot/discovery.py +441 -0
  82. ate/robot/introspection.py +330 -0
  83. ate/robot/llm_system_id.py +654 -0
  84. ate/robot/locomotion_calibration.py +508 -0
  85. ate/robot/manager.py +270 -0
  86. ate/robot/marker_generator.py +611 -0
  87. ate/robot/perception.py +502 -0
  88. ate/robot/primitives.py +614 -0
  89. ate/robot/profiles.py +281 -0
  90. ate/robot/registry.py +322 -0
  91. ate/robot/servo_mapper.py +1153 -0
  92. ate/robot/skill_upload.py +675 -0
  93. ate/robot/target_calibration.py +500 -0
  94. ate/robot/teach.py +515 -0
  95. ate/robot/types.py +242 -0
  96. ate/robot/visual_labeler.py +1048 -0
  97. ate/robot/visual_servo_loop.py +494 -0
  98. ate/robot/visual_servoing.py +570 -0
  99. ate/robot/visual_system_id.py +906 -0
  100. ate/transports/__init__.py +121 -0
  101. ate/transports/base.py +394 -0
  102. ate/transports/ble.py +405 -0
  103. ate/transports/hybrid.py +444 -0
  104. ate/transports/serial.py +345 -0
  105. ate/urdf/__init__.py +30 -0
  106. ate/urdf/capture.py +582 -0
  107. ate/urdf/cloud.py +491 -0
  108. ate/urdf/collision.py +271 -0
  109. ate/urdf/commands.py +708 -0
  110. ate/urdf/depth.py +360 -0
  111. ate/urdf/inertial.py +312 -0
  112. ate/urdf/kinematics.py +330 -0
  113. ate/urdf/lifting.py +415 -0
  114. ate/urdf/meshing.py +300 -0
  115. ate/urdf/models/__init__.py +110 -0
  116. ate/urdf/models/depth_anything.py +253 -0
  117. ate/urdf/models/sam2.py +324 -0
  118. ate/urdf/motion_analysis.py +396 -0
  119. ate/urdf/pipeline.py +468 -0
  120. ate/urdf/scale.py +256 -0
  121. ate/urdf/scan_session.py +411 -0
  122. ate/urdf/segmentation.py +299 -0
  123. ate/urdf/synthesis.py +319 -0
  124. ate/urdf/topology.py +336 -0
  125. ate/urdf/validation.py +371 -0
  126. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/METADATA +9 -1
  127. foodforthought_cli-0.3.0.dist-info/RECORD +166 -0
  128. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/WHEEL +1 -1
  129. foodforthought_cli-0.2.7.dist-info/RECORD +0 -44
  130. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/entry_points.txt +0 -0
  131. {foodforthought_cli-0.2.7.dist-info → foodforthought_cli-0.3.0.dist-info}/top_level.txt +0 -0
ate/memory/migrate/sources/weaviate.py ADDED
@@ -0,0 +1,160 @@
+ """Weaviate migration source implementation."""
+
+ from typing import Optional, List, Tuple, Dict, Any
+
+ try:
+     import weaviate
+     WEAVIATE_AVAILABLE = True
+ except ImportError:
+     # Create a simple mock structure for testing
+     class WeaviateMock:
+         class Client:
+             pass
+         class auth:
+             class AuthApiKey:
+                 def __init__(self, **kwargs):
+                     pass
+
+     weaviate = WeaviateMock()
+     WEAVIATE_AVAILABLE = False
+
+ from ..base import MigrationSource, VectorRecord, MigrationEstimate
+
+
+ class WeaviateMigrationSource(MigrationSource):
+     """Migration source for Weaviate vector database."""
+
+     def __init__(self, host: str, class_name: str, api_key: Optional[str] = None, auth_config: Optional[Dict[str, Any]] = None):
+         """Initialize Weaviate migration source.
+
+         Args:
+             host: Weaviate server host URL
+             class_name: Name of the class to migrate from
+             api_key: Optional API key for authentication
+             auth_config: Optional authentication configuration
+         """
+         self.host = host
+         self.class_name = class_name
+         self.api_key = api_key
+         self.auth_config = auth_config or {}
+         self.client = None
+
+     @property
+     def source_type(self) -> str:
+         """Return the source type."""
+         return "weaviate"
+
+     @property
+     def source_name(self) -> str:
+         """Return the source name."""
+         return self.class_name
+
+     def connect(self) -> None:
+         """Connect to Weaviate."""
+         if not WEAVIATE_AVAILABLE:
+             raise ImportError("weaviate-client library is required for Weaviate migration")
+
+         # Build client configuration
+         if self.api_key:
+             auth = weaviate.auth.AuthApiKey(api_key=self.api_key)
+             self.client = weaviate.Client(url=self.host, auth_client_secret=auth)
+         else:
+             self.client = weaviate.Client(url=self.host)
+
+     def estimate(self) -> MigrationEstimate:
+         """Estimate migration size and time."""
+         if not self.client:
+             raise RuntimeError("Must call connect() first")
+
+         # Get aggregate count using the fluent API
+         result = (self.client.query
+                   .aggregate(self.class_name)
+                   .with_meta_count()
+                   .do())
+
+         aggregate_data = result.get('data', {}).get('Aggregate', {}).get(self.class_name, [])
+
+         if aggregate_data:
+             total_vectors = aggregate_data[0].get('meta', {}).get('count', 0)
+         else:
+             total_vectors = 0
+
+         # Get class schema to determine vector dimensions
+         schema = self.client.schema.get()
+         dimensions = 768  # Default assumption
+
+         for class_def in schema.get('classes', []):
+             if class_def['class'] == self.class_name:
+                 # Look for vectorizer configuration
+                 if 'vectorizer' in class_def:
+                     # Simplified assumption; a real implementation would inspect the vectorizer config
+                     dimensions = 1536 if 'openai' in class_def['vectorizer'].lower() else 768
+                 break
+
+         # Rough estimates
+         bytes_per_vector = dimensions * 4 + 2048  # 4 bytes per float + metadata overhead
+         estimated_mv2_bytes = total_vectors * bytes_per_vector
+         estimated_seconds = total_vectors / 800.0  # Rough estimate of 800 vectors/second
+
+         return MigrationEstimate(
+             total_vectors=total_vectors,
+             dimensions=dimensions,
+             estimated_mv2_bytes=estimated_mv2_bytes,
+             estimated_seconds=estimated_seconds
+         )
+
+     def fetch_batch(self, batch_size: int = 10000, cursor: Optional[str] = None) -> Tuple[List[VectorRecord], Optional[str]]:
+         """Fetch a batch of records from Weaviate."""
+         if not self.client:
+             raise RuntimeError("Must call connect() first")
+
+         # Parse cursor as offset if provided
+         offset = 0
+         if cursor:
+             try:
+                 offset = int(cursor)
+             except (ValueError, TypeError):
+                 offset = 0
+
+         # Query with additional vector data
+         result = (self.client.query
+                   .get(self.class_name)
+                   .with_additional(['id', 'vector'])
+                   .with_limit(batch_size)
+                   .with_offset(offset)
+                   .do())
+
+         if 'errors' in result:
+             raise RuntimeError(f"Weaviate query error: {result['errors']}")
+
+         objects = result.get('data', {}).get('Get', {}).get(self.class_name, [])
+
+         records = []
+         for obj in objects:
+             additional = obj.get('_additional', {})
+             obj_id = additional.get('id', 'unknown')
+             vector = additional.get('vector', [])
+
+             # Extract text and metadata
+             properties = {k: v for k, v in obj.items() if not k.startswith('_')}
+             text = properties.get('text') or properties.get('content')
+
+             # Create clean metadata without the text fields
+             clean_metadata = {k: v for k, v in properties.items() if k not in ('text', 'content')}
+
+             record = VectorRecord(
+                 id=obj_id,
+                 vector=vector,
+                 text=text,
+                 metadata=clean_metadata
+             )
+             records.append(record)
+
+         # Determine if there are more results
+         next_cursor = str(offset + batch_size) if len(objects) == batch_size else None
+
+         return records, next_cursor
+
+     def close(self) -> None:
+         """Close the connection and clean up resources."""
+         # The Weaviate client does not require explicit cleanup
+         self.client = None
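The migration source follows a connect → estimate → fetch_batch → close lifecycle, with the cursor encoding a string offset. A minimal driver sketch, assuming a hypothetical local Weaviate endpoint and class name (only the class and method names come from the code above):

from ate.memory.migrate.sources.weaviate import WeaviateMigrationSource

# "http://localhost:8080" and "Memory" are illustrative placeholders
source = WeaviateMigrationSource(host="http://localhost:8080", class_name="Memory")
source.connect()

est = source.estimate()
print(f"~{est.total_vectors} vectors at {est.dimensions} dims, "
      f"~{est.estimated_mv2_bytes / 1e6:.1f} MB, ~{est.estimated_seconds:.0f}s")

cursor = None
while True:
    records, cursor = source.fetch_batch(batch_size=1000, cursor=cursor)
    for record in records:
        pass  # hand each VectorRecord to the destination store
    if cursor is None:  # fewer than batch_size objects came back
        break

source.close()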
ate/memory/reranker.py ADDED
@@ -0,0 +1,353 @@
+ """LLM re-ranking search engine for ate memory."""
+
+ import json
+ import os
+ import requests
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Optional, List
+
+ from .search import SearchResult
+
+
+ @dataclass
+ class RerankConfig:
+     """Configuration for LLM re-ranking."""
+     provider: str  # "anthropic" | "openai" | "google" | "ollama"
+     model: Optional[str] = None  # Model override (e.g. "claude-3-5-haiku-latest")
+     api_key: Optional[str] = None  # Explicit key (overrides env)
+
+     def __post_init__(self):
+         """Validate provider and set default models after initialization."""
+         valid_providers = ["anthropic", "openai", "google", "ollama"]
+         if self.provider not in valid_providers:
+             raise ValueError(f"Invalid provider: {self.provider}. Must be one of {valid_providers}")
+
+         # Set default models based on provider
+         if self.model is None:
+             if self.provider == "anthropic":
+                 self.model = "claude-3-5-haiku-latest"
+             elif self.provider == "openai":
+                 self.model = "gpt-4o-mini"
+             elif self.provider == "google":
+                 self.model = "gemini-2.0-flash"
+             elif self.provider == "ollama":
+                 self.model = "llama3.2"
+
+
+ class LLMProvider(ABC):
+     """Abstract LLM provider for re-ranking."""
+
+     @abstractmethod
+     def complete(self, prompt: str, max_tokens: int = 200) -> str:
+         """Send a completion request and return the text response."""
+         pass
+
+
+ class AnthropicProvider(LLMProvider):
+     """Anthropic Claude provider."""
+
+     def __init__(self, model: str, api_key: str):
+         self.model = model
+         self.api_key = api_key
+
+     def complete(self, prompt: str, max_tokens: int = 200) -> str:
+         """Send a completion request to the Anthropic API."""
+         headers = {
+             'x-api-key': self.api_key,
+             'anthropic-version': '2023-06-01',
+             'content-type': 'application/json'
+         }
+
+         data = {
+             'model': self.model,
+             'max_tokens': max_tokens,
+             'messages': [{'role': 'user', 'content': prompt}]
+         }
+
+         response = requests.post(
+             'https://api.anthropic.com/v1/messages',
+             headers=headers,
+             json=data,
+             timeout=5
+         )
+         response.raise_for_status()
+
+         result = response.json()
+         return result['content'][0]['text']
+
+
+ class OpenAIProvider(LLMProvider):
+     """OpenAI provider."""
+
+     def __init__(self, model: str, api_key: str):
+         self.model = model
+         self.api_key = api_key
+
+     def complete(self, prompt: str, max_tokens: int = 200) -> str:
+         """Send a completion request to the OpenAI API."""
+         headers = {
+             'Authorization': f'Bearer {self.api_key}',
+             'content-type': 'application/json'
+         }
+
+         data = {
+             'model': self.model,
+             'max_tokens': max_tokens,
+             'messages': [{'role': 'user', 'content': prompt}]
+         }
+
+         response = requests.post(
+             'https://api.openai.com/v1/chat/completions',
+             headers=headers,
+             json=data,
+             timeout=5
+         )
+         response.raise_for_status()
+
+         result = response.json()
+         return result['choices'][0]['message']['content']
+
+
+ class GoogleProvider(LLMProvider):
+     """Google Gemini provider."""
+
+     def __init__(self, model: str, api_key: str):
+         self.model = model
+         self.api_key = api_key
+
+     def complete(self, prompt: str, max_tokens: int = 200) -> str:
+         """Send a completion request to the Google Gemini API."""
+         headers = {
+             'content-type': 'application/json'
+         }
+
+         data = {
+             'contents': [{'parts': [{'text': prompt}]}],
+             'generationConfig': {'maxOutputTokens': max_tokens}
+         }
+
+         url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}'
+         response = requests.post(url, headers=headers, json=data, timeout=5)
+         response.raise_for_status()
+
+         result = response.json()
+         return result['candidates'][0]['content']['parts'][0]['text']
+
+
+ class OllamaLLMProvider(LLMProvider):
+     """Ollama local provider."""
+
+     def __init__(self, model: str, host: str = "http://localhost:11434"):
+         self.model = model
+         self.host = host
+
+     def complete(self, prompt: str, max_tokens: int = 200) -> str:
+         """Send a completion request to the Ollama API."""
+         headers = {
+             'content-type': 'application/json'
+         }
+
+         data = {
+             'model': self.model,
+             'prompt': prompt,
+             'stream': False
+         }
+
+         response = requests.post(
+             f'{self.host}/api/generate',
+             headers=headers,
+             json=data,
+             timeout=30  # Ollama can be slower
+         )
+         response.raise_for_status()
+
+         result = response.json()
+         return result['response']
+
+
+ class LLMReranker:
+     """Re-ranks BM25 search results using an LLM for semantic understanding."""
+
+     def __init__(self, config: RerankConfig):
+         """Initialize with LLM provider config."""
+         self.config = config
+
+     @staticmethod
+     def detect() -> Optional[RerankConfig]:
+         """Auto-detect LLM provider from env vars.
+
+         Detection order: ANTHROPIC_API_KEY → OPENAI_API_KEY → GOOGLE_API_KEY → Ollama
+         """
+         # Check Anthropic first (highest priority)
+         anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
+         if anthropic_key:
+             return RerankConfig(
+                 provider="anthropic",
+                 api_key=anthropic_key
+             )
+
+         # Check OpenAI second
+         openai_key = os.environ.get('OPENAI_API_KEY')
+         if openai_key:
+             return RerankConfig(
+                 provider="openai",
+                 api_key=openai_key
+             )
+
+         # Check Google third
+         google_key = os.environ.get('GOOGLE_API_KEY')
+         if google_key:
+             return RerankConfig(
+                 provider="google",
+                 api_key=google_key
+             )
+
+         # Check Ollama fourth (local service)
+         ollama_host = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
+         if LLMReranker._is_ollama_available(ollama_host):
+             return RerankConfig(
+                 provider="ollama"
+             )
+
+         # No providers available
+         return None
+
+     @staticmethod
+     def _is_ollama_available(host: str) -> bool:
+         """Check if Ollama is reachable at the given host."""
+         try:
+             response = requests.get(f'{host}/api/tags', timeout=2)
+             return response.status_code == 200
+         except requests.RequestException:
+             return False
+
+     def _get_provider(self) -> LLMProvider:
+         """Get the appropriate LLM provider instance."""
+         if self.config.provider == "anthropic":
+             return AnthropicProvider(self.config.model, self.config.api_key)
+         elif self.config.provider == "openai":
+             return OpenAIProvider(self.config.model, self.config.api_key)
+         elif self.config.provider == "google":
+             return GoogleProvider(self.config.model, self.config.api_key)
+         elif self.config.provider == "ollama":
+             ollama_host = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
+             return OllamaLLMProvider(self.config.model, ollama_host)
+         else:
+             raise ValueError(f"Unknown provider: {self.config.provider}")
+
+     def _build_rerank_prompt(self, query: str, candidates: List[SearchResult], top_k: int) -> str:
+         """Build the rerank prompt for the LLM."""
+         lines = [
+             "You are a memory retrieval system. Given a query and a list of memory snippets,",
+             "return the indices of the most relevant snippets, ranked by relevance.",
+             "",
+             f"Query: \"{query}\"",
+             "",
+             "Snippets:"
+         ]
+
+         # Add numbered candidates (truncated to keep the prompt a reasonable size)
+         for i, candidate in enumerate(candidates):
+             title_part = f" — {candidate.title}" if candidate.title else ""
+             text_truncated = candidate.text[:200] + "..." if len(candidate.text) > 200 else candidate.text
+             lines.append(f"[{i}] {text_truncated}{title_part}")
+
+         lines.extend([
+             "",
+             "Return ONLY a JSON array of indices in order of relevance, e.g. [3, 0, 7, 1].",
+             f"Return at most {top_k} indices. Only include relevant results.",
+             "If none are relevant, return []."
+         ])
+
+         return "\n".join(lines)
+
+     def rerank(self, query: str, candidates: List[SearchResult], top_k: int = 5) -> List[SearchResult]:
+         """Re-rank candidates using the LLM.
+
+         Sends a structured prompt to the LLM with the query and candidate texts.
+         Returns re-ordered results with updated scores (1.0 = best match).
+
+         Args:
+             query: Original search query
+             candidates: BM25 search results to re-rank
+             top_k: Number of results to return after re-ranking
+
+         Returns:
+             List of SearchResult objects re-ordered by semantic relevance
+         """
+         # Handle empty candidates
+         if not candidates:
+             return []
+
+         try:
+             # Get LLM provider
+             provider = self._get_provider()
+
+             # Build prompt
+             prompt = self._build_rerank_prompt(query, candidates, top_k)
+
+             # Get LLM response
+             response = provider.complete(prompt, max_tokens=200)
+
+             # Parse response
+             try:
+                 indices = json.loads(response.strip())
+             except json.JSONDecodeError:
+                 # Retry once with a stricter prompt
+                 strict_prompt = prompt + "\n\nIMPORTANT: Return ONLY a valid JSON array like [1, 3, 0]. No other text."
+                 try:
+                     response = provider.complete(strict_prompt, max_tokens=200)
+                     indices = json.loads(response.strip())
+                 except Exception:
+                     # Final fallback to original order
+                     return self._fallback_to_original(candidates, top_k)
+
+             # Validate indices
+             if not isinstance(indices, list):
+                 return self._fallback_to_original(candidates, top_k)
+
+             # Handle empty result
+             if not indices:
+                 return []
+
+             # Reorder candidates based on the LLM ranking
+             reranked = []
+             for rank, idx in enumerate(indices[:top_k]):
+                 if isinstance(idx, int) and 0 <= idx < len(candidates):
+                     candidate = candidates[idx]
+                     # Normalize scores: first = 1.0, decreasing linearly
+                     score = 1.0 - (rank / len(indices)) if len(indices) > 1 else 1.0
+
+                     # Create a new result with updated score and engine
+                     reranked_result = SearchResult(
+                         frame_id=candidate.frame_id,
+                         text=candidate.text,
+                         title=candidate.title,
+                         score=score,
+                         tags=candidate.tags,
+                         metadata=candidate.metadata,
+                         engine="rerank"
+                     )
+                     reranked.append(reranked_result)
+
+             return reranked
+
+         except Exception:
+             # Fall back to original order on any error
+             return self._fallback_to_original(candidates, top_k)
+
+     def _fallback_to_original(self, candidates: List[SearchResult], top_k: int) -> List[SearchResult]:
+         """Fall back to original BM25 order with the rerank engine label."""
+         fallback_results = []
+         for candidate in candidates[:top_k]:
+             fallback_result = SearchResult(
+                 frame_id=candidate.frame_id,
+                 text=candidate.text,
+                 title=candidate.title,
+                 score=candidate.score,  # Keep original scores
+                 tags=candidate.tags,
+                 metadata=candidate.metadata,
+                 engine="rerank"  # Still labeled rerank even though it fell back
+             )
+             fallback_results.append(fallback_result)
+         return fallback_results
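The reranker is designed to wrap an existing BM25 result list: detect() picks a provider from the environment, and rerank() re-orders the candidates, labeling results with engine="rerank". A minimal sketch, with two fabricated candidate results for illustration:

from ate.memory.reranker import LLMReranker
from ate.memory.search import SearchResult

config = LLMReranker.detect()  # None if no API key and no local Ollama is found
if config is not None:
    reranker = LLMReranker(config)
    candidates = [
        SearchResult(frame_id=1, text="Grasp the red cup with the left gripper",
                     title=None, score=4.2, tags=["skill"], metadata={}),
        SearchResult(frame_id=2, text="Battery calibration notes",
                     title=None, score=3.9, tags=["maintenance"], metadata={}),
    ]
    top = reranker.rerank("how do I pick up a cup", candidates, top_k=1)
    # On success, top[0].engine == "rerank" and the best match scores 1.0;
    # on any provider error the original BM25 order comes back instead.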
ate/memory/search.py ADDED
@@ -0,0 +1,26 @@
+ """Search operations and result structures."""
+
+ from dataclasses import dataclass
+ from typing import List, Dict, Any, Optional
+
+
+ @dataclass
+ class SearchResult:
+     """Result from a memory search operation.
+
+     Attributes:
+         frame_id: Unique identifier for the memory frame
+         text: The original content text
+         title: Title of the content (if provided)
+         score: Relevance score (higher = more relevant)
+         tags: List of tags associated with this content
+         metadata: Dictionary of additional metadata
+         engine: Search engine used ("lex" | "vec" | "hybrid" | "rerank")
+     """
+     frame_id: int
+     text: str
+     title: Optional[str]
+     score: float
+     tags: List[str]
+     metadata: Dict[str, Any]
+     engine: str = "lex"
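For reference, constructing a result directly; the field values here are hypothetical, and engine defaults to "lex" until a re-ranking or vector path overwrites it:

from ate.memory.search import SearchResult

hit = SearchResult(frame_id=42, text="Dock at the charging station",
                   title="docking", score=7.1, tags=["nav"],
                   metadata={"robot": "mechdog"})  # illustrative values only
assert hit.engine == "lex"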