claude-memory-agent 2.0.1 → 2.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. package/README.md +206 -206
  2. package/agent_card.py +186 -0
  3. package/bin/cli.js +327 -185
  4. package/bin/lib/banner.js +39 -0
  5. package/bin/lib/environment.js +166 -0
  6. package/bin/lib/installer.js +291 -0
  7. package/bin/lib/models.js +95 -0
  8. package/bin/lib/steps/advanced.js +101 -0
  9. package/bin/lib/steps/confirm.js +87 -0
  10. package/bin/lib/steps/model.js +57 -0
  11. package/bin/lib/steps/provider.js +65 -0
  12. package/bin/lib/steps/scope.js +59 -0
  13. package/bin/lib/steps/server.js +74 -0
  14. package/bin/lib/ui.js +75 -0
  15. package/bin/onboarding.js +164 -0
  16. package/bin/postinstall.js +35 -270
  17. package/config.py +103 -4
  18. package/dashboard.html +4902 -2689
  19. package/hooks/extract_memories.py +439 -0
  20. package/hooks/grounding-hook.py +422 -348
  21. package/hooks/pre_compact_hook.py +76 -0
  22. package/hooks/session_end.py +293 -192
  23. package/hooks/session_end_hook.py +149 -0
  24. package/hooks/session_start.py +227 -227
  25. package/hooks/stop_hook.py +372 -0
  26. package/install.py +972 -902
  27. package/main.py +5240 -2859
  28. package/mcp_server.py +451 -0
  29. package/package.json +58 -47
  30. package/requirements.txt +12 -8
  31. package/services/__init__.py +50 -50
  32. package/services/adaptive_ranker.py +272 -0
  33. package/services/agent_catalog.json +153 -0
  34. package/services/agent_registry.py +245 -730
  35. package/services/claude_md_sync.py +320 -4
  36. package/services/consolidation.py +417 -0
  37. package/services/curator.py +1606 -0
  38. package/services/database.py +4118 -2485
  39. package/services/embedding_pipeline.py +262 -0
  40. package/services/embeddings.py +493 -85
  41. package/services/memory_decay.py +408 -0
  42. package/services/native_memory_paths.py +86 -0
  43. package/services/native_memory_sync.py +496 -0
  44. package/services/response_manager.py +183 -0
  45. package/services/terminal_ui.py +199 -0
  46. package/services/tier_manager.py +235 -0
  47. package/services/websocket.py +26 -6
  48. package/skills/__init__.py +21 -1
  49. package/skills/confidence_tracker.py +441 -0
  50. package/skills/context.py +675 -0
  51. package/skills/curator.py +348 -0
  52. package/skills/search.py +444 -213
  53. package/skills/session_review.py +605 -0
  54. package/skills/store.py +484 -179
  55. package/terminal_dashboard.py +474 -0
  56. package/update_system.py +829 -817
  57. package/hooks/__pycache__/auto-detect-response.cpython-312.pyc +0 -0
  58. package/hooks/__pycache__/auto_capture.cpython-312.pyc +0 -0
  59. package/hooks/__pycache__/session_end.cpython-312.pyc +0 -0
  60. package/hooks/__pycache__/session_start.cpython-312.pyc +0 -0
  61. package/services/__pycache__/__init__.cpython-312.pyc +0 -0
  62. package/services/__pycache__/agent_registry.cpython-312.pyc +0 -0
  63. package/services/__pycache__/auth.cpython-312.pyc +0 -0
  64. package/services/__pycache__/auto_inject.cpython-312.pyc +0 -0
  65. package/services/__pycache__/claude_md_sync.cpython-312.pyc +0 -0
  66. package/services/__pycache__/cleanup.cpython-312.pyc +0 -0
  67. package/services/__pycache__/compaction_flush.cpython-312.pyc +0 -0
  68. package/services/__pycache__/confidence.cpython-312.pyc +0 -0
  69. package/services/__pycache__/daily_log.cpython-312.pyc +0 -0
  70. package/services/__pycache__/database.cpython-312.pyc +0 -0
  71. package/services/__pycache__/embeddings.cpython-312.pyc +0 -0
  72. package/services/__pycache__/insights.cpython-312.pyc +0 -0
  73. package/services/__pycache__/llm_analyzer.cpython-312.pyc +0 -0
  74. package/services/__pycache__/memory_md_sync.cpython-312.pyc +0 -0
  75. package/services/__pycache__/retry_queue.cpython-312.pyc +0 -0
  76. package/services/__pycache__/timeline.cpython-312.pyc +0 -0
  77. package/services/__pycache__/vector_index.cpython-312.pyc +0 -0
  78. package/services/__pycache__/websocket.cpython-312.pyc +0 -0
  79. package/skills/__pycache__/__init__.cpython-312.pyc +0 -0
  80. package/skills/__pycache__/admin.cpython-312.pyc +0 -0
  81. package/skills/__pycache__/checkpoint.cpython-312.pyc +0 -0
  82. package/skills/__pycache__/claude_md.cpython-312.pyc +0 -0
  83. package/skills/__pycache__/cleanup.cpython-312.pyc +0 -0
  84. package/skills/__pycache__/grounding.cpython-312.pyc +0 -0
  85. package/skills/__pycache__/insights.cpython-312.pyc +0 -0
  86. package/skills/__pycache__/natural_language.cpython-312.pyc +0 -0
  87. package/skills/__pycache__/retrieve.cpython-312.pyc +0 -0
  88. package/skills/__pycache__/search.cpython-312.pyc +0 -0
  89. package/skills/__pycache__/state.cpython-312.pyc +0 -0
  90. package/skills/__pycache__/store.cpython-312.pyc +0 -0
  91. package/skills/__pycache__/summarize.cpython-312.pyc +0 -0
  92. package/skills/__pycache__/timeline.cpython-312.pyc +0 -0
  93. package/skills/__pycache__/verification.cpython-312.pyc +0 -0
  94. package/test_automation.py +0 -221
  95. package/test_complete.py +0 -338
  96. package/test_full.py +0 -322
  97. package/verify_db.py +0 -134
@@ -1,21 +1,66 @@
1
- """Embedding service using Ollama with multi-model support.
1
+ """Embedding service with pluggable provider support.
2
2
 
3
- Includes health checks, graceful degradation, and model switching capabilities.
3
+ Supports multiple embedding backends (Ollama, sentence-transformers) via
4
+ a provider abstraction layer. Includes health checks, graceful degradation,
5
+ and model switching capabilities.
4
6
  """
5
7
  import os
6
8
  import time
7
9
  import asyncio
10
+ import logging
11
+ from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass
13
+ from enum import Enum
8
14
  from typing import List, Optional, Dict, Any
9
- import ollama
10
15
  from dotenv import load_dotenv
11
16
 
17
+ # Conditional ollama import
18
+ try:
19
+ import ollama
20
+ HAS_OLLAMA = True
21
+ except ImportError:
22
+ HAS_OLLAMA = False
23
+
24
+ logger = logging.getLogger(__name__)
25
+
12
26
  load_dotenv()
13
27
 
14
28
  OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
15
29
  DEFAULT_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
16
- HEALTH_CHECK_TIMEOUT = float(os.getenv("OLLAMA_HEALTH_TIMEOUT", "2.0"))
30
+ HEALTH_CHECK_TIMEOUT = float(os.getenv("OLLAMA_HEALTH_TIMEOUT", "5.0"))
17
31
  HEALTH_CACHE_TTL = float(os.getenv("OLLAMA_HEALTH_CACHE_TTL", "30.0"))
18
32
 
33
+
34
+ class EmbeddingError(Enum):
35
+ """Distinguishable error codes for embedding failures."""
36
+ NONE = "none" # No error
37
+ EMPTY_TEXT = "empty_text" # Input text was empty or whitespace
38
+ OLLAMA_OFFLINE = "ollama_offline" # Ollama service not reachable
39
+ MODEL_NOT_LOADED = "model_not_loaded" # Model not available in Ollama
40
+ TIMEOUT = "timeout" # Embedding generation timed out
41
+ DEGRADED_MODE = "degraded_mode" # Service in degraded mode, not retrying yet
42
+ UNKNOWN = "unknown" # Unexpected error
43
+
44
+
45
+ @dataclass
46
+ class EmbeddingResult:
47
+ """Result from embedding generation with error context.
48
+
49
+ Allows callers to distinguish failure modes and take appropriate action:
50
+ - EMPTY_TEXT: Skip embedding, store memory without it
51
+ - OLLAMA_OFFLINE: Queue for later re-embedding
52
+ - TIMEOUT: Retry with smaller text or different model
53
+ - DEGRADED_MODE: Wait for auto-recovery
54
+ """
55
+ embedding: Optional[List[float]]
56
+ error: EmbeddingError = EmbeddingError.NONE
57
+ error_message: Optional[str] = None
58
+
59
+ @property
60
+ def ok(self) -> bool:
61
+ return self.embedding is not None and self.error == EmbeddingError.NONE
62
+
63
+
19
64
  # Model configurations: model_name -> dimension
20
65
  MODEL_CONFIGS = {
21
66
  "nomic-embed-text": {"dimension": 768, "description": "General purpose, fast"},
@@ -23,29 +68,200 @@ MODEL_CONFIGS = {
23
68
  "all-minilm": {"dimension": 384, "description": "Lightweight, fast"},
24
69
  "snowflake-arctic-embed": {"dimension": 1024, "description": "High quality, multilingual"},
25
70
  "bge-m3": {"dimension": 1024, "description": "Multilingual, dense retrieval"},
71
+ "gte-large-en-v1.5": {"alias_for": "Alibaba-NLP/gte-large-en-v1.5"},
72
+ "Alibaba-NLP/gte-large-en-v1.5": {"dimension": 1024, "description": "High quality, best STS scores"},
73
+ "all-MiniLM-L6-v2": {"dimension": 384, "description": "Lightweight, fast, sentence-transformers"},
74
+ "BAAI/bge-base-en-v1.5": {"dimension": 768, "description": "Good balance, sentence-transformers"},
26
75
  "default": {"alias_for": "nomic-embed-text"},
27
76
  }
28
77
 
29
78
 
79
+ # ---------------------------------------------------------------------------
80
+ # Provider abstraction
81
+ # ---------------------------------------------------------------------------
82
+
83
+ class EmbeddingProvider(ABC):
84
+ """Abstract base class for embedding providers."""
85
+
86
+ @abstractmethod
87
+ def embed(self, text: str) -> List[float]:
88
+ """Generate embedding for a single text. Runs synchronously."""
89
+ ...
90
+
91
+ @abstractmethod
92
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
93
+ """Generate embeddings for multiple texts. Runs synchronously."""
94
+ ...
95
+
96
+ @abstractmethod
97
+ def check_health(self) -> dict:
98
+ """Check provider health. Returns dict with 'healthy', 'error' keys."""
99
+ ...
100
+
101
+ @abstractmethod
102
+ def get_dimension(self) -> int:
103
+ """Return the embedding dimension."""
104
+ ...
105
+
106
+ @abstractmethod
107
+ def get_model_name(self) -> str:
108
+ """Return the model name."""
109
+ ...
110
+
111
+
112
+ class OllamaProvider(EmbeddingProvider):
113
+ """Embedding provider backed by a local Ollama instance."""
114
+
115
+ def __init__(self, host: str = OLLAMA_HOST, model: str = DEFAULT_MODEL):
116
+ if not HAS_OLLAMA:
117
+ raise RuntimeError(
118
+ "ollama package not installed. Run: pip install ollama"
119
+ )
120
+ self.host = host
121
+ self.model = model
122
+ self.client = ollama.Client(host=host)
123
+
124
+ def embed(self, text: str) -> List[float]:
125
+ response = self.client.embeddings(model=self.model, prompt=text)
126
+ return response["embedding"]
127
+
128
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
129
+ # Ollama has no native batch endpoint; call embed() sequentially
130
+ return [self.embed(t) for t in texts]
131
+
132
+ def check_health(self) -> dict:
133
+ try:
134
+ models = self.client.list()
135
+ model_names = [
136
+ m.get("name", m.get("model", ""))
137
+ for m in models.get("models", [])
138
+ ]
139
+ model_loaded = any(self.model in name for name in model_names)
140
+ return {
141
+ "healthy": True,
142
+ "model": self.model,
143
+ "model_loaded": model_loaded,
144
+ "provider": "ollama",
145
+ "host": self.host,
146
+ "error": None,
147
+ "available_models": model_names,
148
+ }
149
+ except Exception as e:
150
+ return {
151
+ "healthy": False,
152
+ "model": self.model,
153
+ "provider": "ollama",
154
+ "host": self.host,
155
+ "error": str(e),
156
+ }
157
+
158
+ def get_dimension(self) -> int:
159
+ config = MODEL_CONFIGS.get(self.model, {})
160
+ if "alias_for" in config:
161
+ config = MODEL_CONFIGS.get(config["alias_for"], {})
162
+ return config.get("dimension", 768)
163
+
164
+ def get_model_name(self) -> str:
165
+ return self.model
166
+
167
+
168
+ class SentenceTransformerProvider(EmbeddingProvider):
169
+ """Embedding provider using the sentence-transformers library."""
170
+
171
+ def __init__(self, model: str = "Alibaba-NLP/gte-large-en-v1.5"):
172
+ try:
173
+ from sentence_transformers import SentenceTransformer
174
+ except ImportError:
175
+ raise RuntimeError(
176
+ "sentence-transformers package not installed. "
177
+ "Run: pip install sentence-transformers"
178
+ )
179
+
180
+ self.model_name = model
181
+ self._model = SentenceTransformer(model, trust_remote_code=True)
182
+ self._dimension = self._model.get_sentence_embedding_dimension()
183
+
184
+ def embed(self, text: str) -> List[float]:
185
+ embedding = self._model.encode(text, normalize_embeddings=True)
186
+ return embedding.tolist()
187
+
188
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
189
+ embeddings = self._model.encode(
190
+ texts, normalize_embeddings=True, batch_size=32
191
+ )
192
+ return embeddings.tolist()
193
+
194
+ def check_health(self) -> dict:
195
+ return {
196
+ "healthy": True,
197
+ "model": self.model_name,
198
+ "provider": "sentence-transformers",
199
+ "error": None,
200
+ }
201
+
202
+ def get_dimension(self) -> int:
203
+ return self._dimension
204
+
205
+ def get_model_name(self) -> str:
206
+ return self.model_name
207
+
208
+
209
+ # ---------------------------------------------------------------------------
210
+ # EmbeddingService (public API unchanged)
211
+ # ---------------------------------------------------------------------------
212
+
30
213
  class EmbeddingService:
31
- """Service for generating embeddings using Ollama with multi-model support.
214
+ """Service for generating embeddings with pluggable provider backends.
32
215
 
33
216
  Features:
34
217
  - Multiple model support with automatic dimension handling
35
- - Health check with caching to avoid hammering Ollama
36
- - Graceful degradation: returns None when Ollama unavailable
37
- - Timeout handling for unresponsive Ollama instances
218
+ - Health check with caching to avoid hammering the provider
219
+ - Graceful degradation: returns None when provider unavailable
220
+ - Timeout handling for unresponsive backends
38
221
  - Model switching without data loss
39
222
  """
40
223
 
41
- def __init__(self, model: Optional[str] = None):
42
- self.host = OLLAMA_HOST
43
- self.client = ollama.Client(host=OLLAMA_HOST)
224
+ def __init__(
225
+ self,
226
+ provider_type: str = "sentence-transformers",
227
+ model: Optional[str] = None,
228
+ ):
229
+ self.provider_type = provider_type
230
+
231
+ # Determine model
232
+ if model:
233
+ self.model = self._resolve_model(model)
234
+ elif provider_type == "sentence-transformers":
235
+ self.model = "Alibaba-NLP/gte-large-en-v1.5"
236
+ else:
237
+ self.model = DEFAULT_MODEL # nomic-embed-text
238
+
239
+ # Guard: Ollama-only model names don't exist on HuggingFace
240
+ OLLAMA_ONLY_MODELS = {"nomic-embed-text", "mxbai-embed-large", "all-minilm",
241
+ "snowflake-arctic-embed", "bge-m3"}
242
+ if provider_type == "sentence-transformers" and self.model in OLLAMA_ONLY_MODELS:
243
+ logger.warning(
244
+ f"Model '{self.model}' is an Ollama-only model but provider is "
245
+ f"sentence-transformers. Falling back to Alibaba-NLP/gte-large-en-v1.5. "
246
+ f"Update EMBEDDING_MODEL in your .env to fix this."
247
+ )
248
+ self.model = "Alibaba-NLP/gte-large-en-v1.5"
44
249
 
45
- # Resolve model (handle aliases)
46
- self.model = self._resolve_model(model or DEFAULT_MODEL)
47
250
  self._model_config = self._get_model_config(self.model)
48
251
 
252
+ # Create provider
253
+ if provider_type == "ollama":
254
+ self._provider = OllamaProvider(host=OLLAMA_HOST, model=self.model)
255
+ self.host = OLLAMA_HOST
256
+ elif provider_type == "sentence-transformers":
257
+ self._provider = SentenceTransformerProvider(model=self.model)
258
+ self.host = "local"
259
+ else:
260
+ raise ValueError(
261
+ f"Unknown provider: {provider_type}. "
262
+ "Use 'ollama' or 'sentence-transformers'"
263
+ )
264
+
49
265
  # Health check caching
50
266
  self._health_status: Optional[bool] = None
51
267
  self._health_last_check: float = 0
@@ -60,6 +276,10 @@ class EmbeddingService:
60
276
  self._available_models: Optional[List[str]] = None
61
277
  self._models_last_check: float = 0
62
278
 
279
+ # ------------------------------------------------------------------
280
+ # Internal helpers
281
+ # ------------------------------------------------------------------
282
+
63
283
  def _resolve_model(self, model: str) -> str:
64
284
  """Resolve model aliases to actual model names."""
65
285
  config = MODEL_CONFIGS.get(model, {})
@@ -71,11 +291,24 @@ class EmbeddingService:
71
291
  """Get configuration for a model."""
72
292
  if model in MODEL_CONFIGS:
73
293
  return MODEL_CONFIGS[model]
74
- # Default config for unknown models
75
294
  return {"dimension": 768, "description": "Unknown model"}
76
295
 
296
+ def _enter_degraded_mode(self):
297
+ """Enter degraded mode when provider is unavailable."""
298
+ if not self._degraded_mode:
299
+ self._degraded_mode = True
300
+ self._degraded_since = time.time()
301
+
302
+ def _is_local_provider(self) -> bool:
303
+ """Return True if the provider runs locally with no remote dependency."""
304
+ return self.provider_type == "sentence-transformers"
305
+
306
+ # ------------------------------------------------------------------
307
+ # Health
308
+ # ------------------------------------------------------------------
309
+
77
310
  async def check_health(self, force: bool = False) -> Dict[str, Any]:
78
- """Check if Ollama is healthy and responsive.
311
+ """Check if the embedding provider is healthy and responsive.
79
312
 
80
313
  Args:
81
314
  force: If True, bypass cache and check immediately
@@ -93,52 +326,45 @@ class EmbeddingService:
93
326
  "cached": True,
94
327
  "model": self.model,
95
328
  "host": self.host,
329
+ "provider": self.provider_type,
96
330
  "error": self._health_error,
97
- "degraded_mode": self._degraded_mode
331
+ "degraded_mode": self._degraded_mode,
98
332
  }
99
333
 
100
- # Perform health check with timeout
101
334
  start_time = time.time()
102
335
  try:
103
336
  loop = asyncio.get_event_loop()
104
337
 
105
- def _check():
106
- # Try to list models to verify Ollama is responding
107
- models = self.client.list()
108
- model_names = [m.get('name', m.get('model', '')) for m in models.get('models', [])]
109
- # Check if our model is available
110
- model_loaded = any(self.model in name for name in model_names)
111
- return models, model_loaded, model_names
112
-
113
- # Run with timeout
114
- models, model_loaded, model_names = await asyncio.wait_for(
115
- loop.run_in_executor(None, _check),
116
- timeout=HEALTH_CHECK_TIMEOUT
338
+ health_result = await asyncio.wait_for(
339
+ loop.run_in_executor(None, self._provider.check_health),
340
+ timeout=HEALTH_CHECK_TIMEOUT,
117
341
  )
118
342
 
119
343
  latency_ms = (time.time() - start_time) * 1000
120
344
 
121
- self._health_status = True
345
+ self._health_status = health_result.get("healthy", False)
122
346
  self._health_last_check = now
123
- self._health_error = None
124
- self._available_models = model_names
125
- self._models_last_check = now
347
+ self._health_error = health_result.get("error")
348
+ self._available_models = health_result.get("available_models")
349
+ if self._available_models is not None:
350
+ self._models_last_check = now
126
351
 
127
- # Exit degraded mode if we were in it
128
- if self._degraded_mode:
352
+ # Exit degraded mode on success
353
+ if self._health_status and self._degraded_mode:
129
354
  self._degraded_mode = False
130
355
  self._degraded_since = None
131
356
 
132
357
  return {
133
- "healthy": True,
358
+ "healthy": self._health_status,
134
359
  "cached": False,
135
360
  "model": self.model,
136
- "model_loaded": model_loaded,
361
+ "model_loaded": health_result.get("model_loaded", True),
137
362
  "host": self.host,
363
+ "provider": self.provider_type,
138
364
  "latency_ms": round(latency_ms, 2),
139
- "error": None,
140
- "degraded_mode": False,
141
- "available_models": model_names
365
+ "error": self._health_error,
366
+ "degraded_mode": False if self._health_status else self._degraded_mode,
367
+ "available_models": self._available_models,
142
368
  }
143
369
 
144
370
  except asyncio.TimeoutError:
@@ -152,8 +378,9 @@ class EmbeddingService:
152
378
  "cached": False,
153
379
  "model": self.model,
154
380
  "host": self.host,
381
+ "provider": self.provider_type,
155
382
  "error": self._health_error,
156
- "degraded_mode": True
383
+ "degraded_mode": True,
157
384
  }
158
385
 
159
386
  except Exception as e:
@@ -167,16 +394,11 @@ class EmbeddingService:
167
394
  "cached": False,
168
395
  "model": self.model,
169
396
  "host": self.host,
397
+ "provider": self.provider_type,
170
398
  "error": self._health_error,
171
- "degraded_mode": True
399
+ "degraded_mode": True,
172
400
  }
173
401
 
174
- def _enter_degraded_mode(self):
175
- """Enter degraded mode when Ollama is unavailable."""
176
- if not self._degraded_mode:
177
- self._degraded_mode = True
178
- self._degraded_since = time.time()
179
-
180
402
  def is_degraded(self) -> bool:
181
403
  """Check if service is in degraded mode."""
182
404
  return self._degraded_mode
@@ -187,11 +409,15 @@ class EmbeddingService:
187
409
  return time.time() - self._degraded_since
188
410
  return None
189
411
 
412
+ # ------------------------------------------------------------------
413
+ # Embedding generation
414
+ # ------------------------------------------------------------------
415
+
190
416
  async def generate_embedding(
191
417
  self,
192
418
  text: str,
193
419
  model: Optional[str] = None,
194
- fallback_on_error: bool = True
420
+ fallback_on_error: bool = True,
195
421
  ) -> Optional[List[float]]:
196
422
  """Generate embedding for a single text.
197
423
 
@@ -201,35 +427,34 @@ class EmbeddingService:
201
427
  fallback_on_error: If True, return None instead of raising on error
202
428
 
203
429
  Returns:
204
- List of floats (embedding) or None if Ollama unavailable and fallback enabled
430
+ List of floats (embedding) or None if provider unavailable and fallback enabled
205
431
  """
206
- use_model = self._resolve_model(model) if model else self.model
207
-
208
- # Quick check if we're in degraded mode
209
- if self._degraded_mode:
210
- # Check if we should retry (every 30s)
432
+ # For local providers, skip degraded-mode gating
433
+ if not self._is_local_provider() and self._degraded_mode:
211
434
  if time.time() - self._health_last_check >= self._health_cache_ttl:
212
435
  health = await self.check_health(force=True)
213
436
  if not health["healthy"]:
214
437
  if fallback_on_error:
215
438
  return None
216
- raise ConnectionError(f"Ollama unavailable: {health['error']}")
439
+ raise ConnectionError(
440
+ f"Provider unavailable: {health['error']}"
441
+ )
217
442
  elif fallback_on_error:
218
443
  return None
219
444
  else:
220
- raise ConnectionError(f"Ollama unavailable (degraded mode): {self._health_error}")
445
+ raise ConnectionError(
446
+ f"Provider unavailable (degraded mode): {self._health_error}"
447
+ )
221
448
 
222
449
  try:
223
450
  loop = asyncio.get_event_loop()
224
451
 
225
452
  def _embed():
226
- response = self.client.embeddings(model=use_model, prompt=text)
227
- return response["embedding"]
453
+ return self._provider.embed(text)
228
454
 
229
- # Run with timeout
230
455
  embedding = await asyncio.wait_for(
231
456
  loop.run_in_executor(None, _embed),
232
- timeout=30.0 # 30s timeout for embedding generation
457
+ timeout=30.0,
233
458
  )
234
459
  return embedding
235
460
 
@@ -241,7 +466,6 @@ class EmbeddingService:
241
466
  raise
242
467
 
243
468
  except Exception as e:
244
- # Check if it's a connection error
245
469
  error_str = str(e).lower()
246
470
  if "connection" in error_str or "refused" in error_str or "timeout" in error_str:
247
471
  self._enter_degraded_mode()
@@ -251,33 +475,201 @@ class EmbeddingService:
251
475
  return None
252
476
  raise
253
477
 
478
+ async def generate_embedding_with_status(
479
+ self,
480
+ text: str,
481
+ model: Optional[str] = None,
482
+ ) -> EmbeddingResult:
483
+ """Generate embedding with detailed error status.
484
+
485
+ Unlike generate_embedding() which returns None for all failures,
486
+ this method returns an EmbeddingResult with a specific error code
487
+ so callers can distinguish:
488
+ - Empty input text
489
+ - Ollama offline (connection refused)
490
+ - Model not loaded
491
+ - Generation timeout
492
+ - Degraded mode (waiting for auto-recovery)
493
+
494
+ Args:
495
+ text: Text to embed
496
+ model: Optional model override
497
+
498
+ Returns:
499
+ EmbeddingResult with embedding and error details
500
+ """
501
+ # Validate input
502
+ if not text or not text.strip():
503
+ return EmbeddingResult(
504
+ embedding=None,
505
+ error=EmbeddingError.EMPTY_TEXT,
506
+ error_message="Input text is empty or whitespace-only",
507
+ )
508
+
509
+ # For local providers, skip degraded-mode gating
510
+ if not self._is_local_provider() and self._degraded_mode:
511
+ if time.time() - self._health_last_check >= self._health_cache_ttl:
512
+ health = await self.check_health(force=True)
513
+ if not health["healthy"]:
514
+ return EmbeddingResult(
515
+ embedding=None,
516
+ error=EmbeddingError.OLLAMA_OFFLINE,
517
+ error_message=f"Provider unavailable: {health.get('error', 'unknown')}",
518
+ )
519
+ else:
520
+ return EmbeddingResult(
521
+ embedding=None,
522
+ error=EmbeddingError.DEGRADED_MODE,
523
+ error_message=(
524
+ f"Service in degraded mode since {self._degraded_since:.0f}. "
525
+ f"Next retry in {self._health_cache_ttl - (time.time() - self._health_last_check):.0f}s"
526
+ ),
527
+ )
528
+
529
+ try:
530
+ loop = asyncio.get_event_loop()
531
+
532
+ def _embed():
533
+ return self._provider.embed(text)
534
+
535
+ embedding = await asyncio.wait_for(
536
+ loop.run_in_executor(None, _embed),
537
+ timeout=30.0,
538
+ )
539
+
540
+ # Exit degraded mode on success
541
+ if self._degraded_mode:
542
+ self._degraded_mode = False
543
+ self._degraded_since = None
544
+
545
+ return EmbeddingResult(embedding=embedding)
546
+
547
+ except asyncio.TimeoutError:
548
+ self._enter_degraded_mode()
549
+ self._health_error = "Embedding generation timed out"
550
+ return EmbeddingResult(
551
+ embedding=None,
552
+ error=EmbeddingError.TIMEOUT,
553
+ error_message=f"Embedding generation timed out after 30s for model {self.model}",
554
+ )
555
+
556
+ except Exception as e:
557
+ error_str = str(e).lower()
558
+
559
+ if "connection" in error_str or "refused" in error_str:
560
+ self._enter_degraded_mode()
561
+ self._health_error = str(e)
562
+ return EmbeddingResult(
563
+ embedding=None,
564
+ error=EmbeddingError.OLLAMA_OFFLINE,
565
+ error_message=f"Provider not reachable: {e}",
566
+ )
567
+
568
+ if "model" in error_str and (
569
+ "not found" in error_str or "not exist" in error_str
570
+ ):
571
+ return EmbeddingResult(
572
+ embedding=None,
573
+ error=EmbeddingError.MODEL_NOT_LOADED,
574
+ error_message=f"Model '{self.model}' not available: {e}",
575
+ )
576
+
577
+ if "timeout" in error_str:
578
+ self._enter_degraded_mode()
579
+ self._health_error = str(e)
580
+ return EmbeddingResult(
581
+ embedding=None,
582
+ error=EmbeddingError.TIMEOUT,
583
+ error_message=f"Timeout: {e}",
584
+ )
585
+
586
+ logger.warning(f"Unexpected embedding error: {e}")
587
+ return EmbeddingResult(
588
+ embedding=None,
589
+ error=EmbeddingError.UNKNOWN,
590
+ error_message=str(e),
591
+ )
592
+
254
593
  async def generate_embeddings(
255
594
  self,
256
595
  texts: List[str],
257
596
  model: Optional[str] = None,
258
- fallback_on_error: bool = True
597
+ fallback_on_error: bool = True,
598
+ batch_size: int = 10,
259
599
  ) -> List[Optional[List[float]]]:
260
600
  """Generate embeddings for multiple texts.
261
601
 
602
+ For providers with native batch support (sentence-transformers), uses
603
+ the provider's batch method directly. Otherwise falls back to
604
+ concurrent individual requests.
605
+
262
606
  Args:
263
607
  texts: List of texts to embed
264
608
  model: Optional model override
265
609
  fallback_on_error: If True, include None for failed embeddings
610
+ batch_size: Number of concurrent embedding requests per batch
266
611
 
267
612
  Returns:
268
613
  List of embeddings (or None for failed ones if fallback enabled)
269
614
  """
270
- embeddings = []
271
- for text in texts:
272
- embedding = await self.generate_embedding(text, model, fallback_on_error)
273
- embeddings.append(embedding)
274
- return embeddings
615
+ if not texts:
616
+ return []
617
+
618
+ # sentence-transformers has efficient native batching
619
+ if self.provider_type == "sentence-transformers":
620
+ try:
621
+ loop = asyncio.get_event_loop()
622
+
623
+ def _batch_embed():
624
+ return self._provider.embed_batch(texts)
625
+
626
+ results = await asyncio.wait_for(
627
+ loop.run_in_executor(None, _batch_embed),
628
+ timeout=max(30.0, len(texts) * 2.0),
629
+ )
630
+ return results
631
+ except Exception as e:
632
+ if fallback_on_error:
633
+ logger.warning(f"Batch embedding failed: {e}")
634
+ return [None] * len(texts)
635
+ raise
636
+
637
+ # For other providers, use concurrent individual requests
638
+ results: List[Optional[List[float]]] = [None] * len(texts)
639
+
640
+ for batch_start in range(0, len(texts), batch_size):
641
+ batch_texts = texts[batch_start : batch_start + batch_size]
642
+ batch_results = await asyncio.gather(
643
+ *[
644
+ self.generate_embedding(text, model, fallback_on_error)
645
+ for text in batch_texts
646
+ ],
647
+ return_exceptions=True,
648
+ )
649
+
650
+ for i, result in enumerate(batch_results):
651
+ idx = batch_start + i
652
+ if isinstance(result, Exception):
653
+ if fallback_on_error:
654
+ results[idx] = None
655
+ else:
656
+ raise result
657
+ else:
658
+ results[idx] = result
659
+
660
+ return results
661
+
662
+ # ------------------------------------------------------------------
663
+ # Model / status helpers
664
+ # ------------------------------------------------------------------
275
665
 
276
666
  def get_dimension(self, model: Optional[str] = None) -> int:
277
667
  """Return the embedding dimension for a model."""
278
- use_model = self._resolve_model(model) if model else self.model
279
- config = self._get_model_config(use_model)
280
- return config.get("dimension", 768)
668
+ if model:
669
+ use_model = self._resolve_model(model)
670
+ config = self._get_model_config(use_model)
671
+ return config.get("dimension", 768)
672
+ return self._provider.get_dimension()
281
673
 
282
674
  def get_current_model(self) -> str:
283
675
  """Get the current default model."""
@@ -297,7 +689,7 @@ class EmbeddingService:
297
689
  models = []
298
690
  for name, config in MODEL_CONFIGS.items():
299
691
  if "alias_for" in config:
300
- continue # Skip aliases
692
+ continue
301
693
  models.append({
302
694
  "name": name,
303
695
  "dimension": config.get("dimension", 768),
@@ -305,24 +697,32 @@ class EmbeddingService:
305
697
  "is_current": name == self.model,
306
698
  "available_in_ollama": (
307
699
  any(name in m for m in (self._available_models or []))
308
- if self._available_models else None
309
- )
700
+ if self._available_models
701
+ else None
702
+ ),
310
703
  })
311
704
  return models
312
705
 
313
706
  async def get_ollama_models(self) -> List[str]:
314
707
  """Get list of models currently available in Ollama."""
708
+ if self.provider_type != "ollama":
709
+ return []
710
+
315
711
  if self._available_models and (time.time() - self._models_last_check) < 60:
316
712
  return self._available_models
317
713
 
318
714
  try:
319
715
  loop = asyncio.get_event_loop()
320
- models = await loop.run_in_executor(None, self.client.list)
321
- model_names = [m.get('name', m.get('model', '')) for m in models.get('models', [])]
716
+ provider: OllamaProvider = self._provider # type: ignore[assignment]
717
+ models = await loop.run_in_executor(None, provider.client.list)
718
+ model_names = [
719
+ m.get("name", m.get("model", ""))
720
+ for m in models.get("models", [])
721
+ ]
322
722
  self._available_models = model_names
323
723
  self._models_last_check = time.time()
324
724
  return model_names
325
- except:
725
+ except Exception:
326
726
  return self._available_models or []
327
727
 
328
728
  def get_status(self) -> Dict[str, Any]:
@@ -331,28 +731,36 @@ class EmbeddingService:
331
731
  "model": self.model,
332
732
  "dimension": self.get_dimension(),
333
733
  "host": self.host,
734
+ "provider": self.provider_type,
334
735
  "degraded_mode": self._degraded_mode,
335
736
  "degraded_since": self._degraded_since,
336
737
  "degraded_duration_seconds": self.get_degraded_duration(),
337
738
  "last_health_check": self._health_last_check,
338
739
  "last_health_status": self._health_status,
339
740
  "last_health_error": self._health_error,
340
- "available_models_in_ollama": self._available_models
741
+ "available_models_in_ollama": self._available_models,
341
742
  }
342
743
 
343
744
 
344
- # Global registry of embedding services per model
745
+ # Global registry of embedding services per provider:model
345
746
  _embedding_services: Dict[str, EmbeddingService] = {}
346
747
 
347
748
 
348
- def get_embedding_service(model: Optional[str] = None) -> EmbeddingService:
349
- """Get an embedding service for a specific model.
749
+ def get_embedding_service(
750
+ model: Optional[str] = None,
751
+ provider_type: Optional[str] = None,
752
+ ) -> EmbeddingService:
753
+ """Get an embedding service for a specific provider and model.
350
754
 
351
- Uses a shared instance per model to maintain health check state.
755
+ Uses a shared instance per provider:model to maintain health check state.
352
756
  """
757
+ provider = provider_type or os.getenv("EMBEDDING_PROVIDER", "sentence-transformers")
353
758
  model_key = model or DEFAULT_MODEL
759
+ cache_key = f"{provider}:{model_key}"
354
760
 
355
- if model_key not in _embedding_services:
356
- _embedding_services[model_key] = EmbeddingService(model_key)
761
+ if cache_key not in _embedding_services:
762
+ _embedding_services[cache_key] = EmbeddingService(
763
+ provider_type=provider, model=model_key
764
+ )
357
765
 
358
- return _embedding_services[model_key]
766
+ return _embedding_services[cache_key]