cembedding 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cembedding/__init__.py +3 -0
- cembedding/__main__.py +6 -0
- cembedding/_vendored_mcp_common/__init__.py +14 -0
- cembedding/_vendored_mcp_common/mcp_utils.py +152 -0
- cembedding/_vendored_mcp_common/validation.py +65 -0
- cembedding/download_model.py +132 -0
- cembedding/server.py +1351 -0
- cembedding-0.5.0.dist-info/METADATA +138 -0
- cembedding-0.5.0.dist-info/RECORD +12 -0
- cembedding-0.5.0.dist-info/WHEEL +4 -0
- cembedding-0.5.0.dist-info/entry_points.txt +3 -0
- cembedding-0.5.0.dist-info/licenses/LICENSE +21 -0
cembedding/server.py
ADDED
|
@@ -0,0 +1,1351 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cloto MCP Server: Vector Embedding
|
|
3
|
+
Pluggable embedding provider with HTTP endpoint for inter-server communication.
|
|
4
|
+
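Providers: api_openai (OpenAI-compatible API), onnx_miniml (local MiniLM ONNX), onnx_jina_v5_nano (local Jina v5 nano ONNX), onnx_bge_m3 (local BGE-M3 ONNX), mlx_bge_m3 (BGE-M3 on Apple Silicon via MLX).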
|
|
5
|
+
|
|
6
|
+
v0.2.0: Vector Index — persistent index + search endpoints for centralized vector search.
|
|
7
|
+
|
|
8
|
+
Design: docs/CPERSONA_MEMORY_DESIGN.md Section 5
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import platform as _platform
|
|
15
|
+
import struct
|
|
16
|
+
import sys
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
|
|
19
|
+
import httpx
|
|
20
|
+
import numpy as np
|
|
21
|
+
from aiohttp import web
|
|
22
|
+
from mcp.server.stdio import stdio_server
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
from cembedding._vendored_mcp_common.mcp_utils import ToolRegistry
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
# ============================================================
|
|
30
|
+
# Configuration
|
|
31
|
+
# ============================================================
|
|
32
|
+
|
|
33
|
+
EMBEDDING_PROVIDER = os.environ.get("EMBEDDING_PROVIDER", "api_openai")

EMBEDDING_HTTP_PORT = int(os.environ.get("EMBEDDING_HTTP_PORT", "8401"))
if not (1 <= EMBEDDING_HTTP_PORT <= 65535):
    raise ValueError(f"EMBEDDING_HTTP_PORT must be 1-65535, got {EMBEDDING_HTTP_PORT}")

EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY", "")
EMBEDDING_API_URL = os.environ.get("EMBEDDING_API_URL", "https://api.openai.com/v1/embeddings")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "")  # provider-dependent default

# HTTP client timeout in seconds (used by the api_openai provider). Validated
# alongside the port and sequence length so a zero/negative value fails fast at
# startup instead of surfacing later as failed embedding requests.
EMBEDDING_TIMEOUT = int(os.environ.get("EMBEDDING_TIMEOUT_SECS", "30"))
if EMBEDDING_TIMEOUT <= 0:
    raise ValueError(f"EMBEDDING_TIMEOUT_SECS must be > 0, got {EMBEDDING_TIMEOUT}")

# Vector Index (v0.2.0). Only the exact string "true" (case-insensitive) enables it.
EMBEDDING_INDEX_ENABLED = os.environ.get("EMBEDDING_INDEX_ENABLED", "true").lower() == "true"
EMBEDDING_INDEX_DB_PATH = os.environ.get("EMBEDDING_INDEX_DB_PATH", "data/embedding_index.db")

# ONNX tokenization max sequence length (1-8192). MiniLM is clamped to 512
# internally (positional embeddings cap). Jina-v5-nano supports up to 8192.
ONNX_MAX_SEQ_LEN = int(os.environ.get("ONNX_MAX_SEQ_LEN", "2048"))
if not (1 <= ONNX_MAX_SEQ_LEN <= 8192):
    raise ValueError(f"ONNX_MAX_SEQ_LEN must be 1-8192, got {ONNX_MAX_SEQ_LEN}")

# Model directory resolution. When ONNX_MODEL_DIR is unset, the provider's
# built-in (relative) default directory is resolved against CLOTO_PROJECT_DIR
# if that is set, because sandbox isolation changes the working directory.
# An explicitly configured ONNX_MODEL_DIR is used verbatim.
_project_dir = os.environ.get("CLOTO_PROJECT_DIR", "")
_MODEL_DIRS = {
    "onnx_miniml": "data/models/all-MiniLM-L6-v2",
    "onnx_jina_v5_nano": "data/models/jina-embeddings-v5-text-nano",
    "onnx_bge_m3": "data/models/bge-m3",
    "mlx_bge_m3": "data/models/bge-m3-mlx",
}
# Providers without a local model (e.g. api_openai) fall back to the MiniLM directory.
_default_model_dir = _MODEL_DIRS.get(EMBEDDING_PROVIDER, "data/models/all-MiniLM-L6-v2")
ONNX_MODEL_DIR = os.environ.get("ONNX_MODEL_DIR", "")
if not ONNX_MODEL_DIR:
    if _project_dir and not os.path.isabs(_default_model_dir):
        ONNX_MODEL_DIR = os.path.join(_project_dir, _default_model_dir)
    else:
        ONNX_MODEL_DIR = _default_model_dir
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _select_ort_providers() -> list:
|
|
71
|
+
"""Select ONNX Runtime execution providers with cross-platform fallback.
|
|
72
|
+
|
|
73
|
+
Priority when ONNX_EP_PREFERENCE is empty (auto-detect):
|
|
74
|
+
1. CoreMLExecutionProvider (macOS — standard onnxruntime)
|
|
75
|
+
2. DmlExecutionProvider (Windows — standard onnxruntime)
|
|
76
|
+
3. CUDAExecutionProvider (Linux/Windows NVIDIA — requires onnxruntime-gpu)
|
|
77
|
+
4. ROCmExecutionProvider (Linux AMD — requires onnxruntime ROCm build)
|
|
78
|
+
5. CPUExecutionProvider (always appended as terminal fallback)
|
|
79
|
+
|
|
80
|
+
When ONNX_EP_PREFERENCE is set, use the comma-separated list but always
|
|
81
|
+
filter against get_available_providers() and always ensure CPUExecutionProvider
|
|
82
|
+
is present so session creation cannot fail for lack of any provider.
|
|
83
|
+
|
|
84
|
+
Fail-open: if onnxruntime import or get_available_providers() raises,
|
|
85
|
+
return ["CPUExecutionProvider"] so the caller can still attempt to load.
|
|
86
|
+
"""
|
|
87
|
+
try:
|
|
88
|
+
import onnxruntime as ort
|
|
89
|
+
|
|
90
|
+
available = set(ort.get_available_providers())
|
|
91
|
+
except Exception:
|
|
92
|
+
return ["CPUExecutionProvider"]
|
|
93
|
+
|
|
94
|
+
preference = os.environ.get("ONNX_EP_PREFERENCE", "").strip()
|
|
95
|
+
|
|
96
|
+
if preference:
|
|
97
|
+
requested = [p.strip() for p in preference.split(",") if p.strip()]
|
|
98
|
+
providers = [p for p in requested if p in available]
|
|
99
|
+
else:
|
|
100
|
+
providers = []
|
|
101
|
+
for candidate in (
|
|
102
|
+
"CoreMLExecutionProvider", # macOS
|
|
103
|
+
"DmlExecutionProvider", # Windows (DirectML)
|
|
104
|
+
"CUDAExecutionProvider", # Linux/Windows NVIDIA (onnxruntime-gpu)
|
|
105
|
+
"ROCmExecutionProvider", # Linux AMD (onnxruntime ROCm)
|
|
106
|
+
):
|
|
107
|
+
if candidate in available:
|
|
108
|
+
providers.append(candidate)
|
|
109
|
+
|
|
110
|
+
if "CPUExecutionProvider" not in providers:
|
|
111
|
+
providers.append("CPUExecutionProvider")
|
|
112
|
+
|
|
113
|
+
return providers
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ============================================================
|
|
117
|
+
# Provider Abstraction
|
|
118
|
+
# ============================================================
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class EmbeddingProvider(ABC):
    """Common interface implemented by every embedding backend.

    Subclasses must provide ``initialize``, ``embed`` and ``dimensions``.
    ``shutdown`` has a no-op default and may be overridden to release resources.
    """

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the backend for use (e.g. load a model or open an HTTP client)."""

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts, returning one vector per entry of ``texts``."""

    @abstractmethod
    def dimensions(self) -> int:
        """Width of the vectors produced by ``embed``."""

    async def shutdown(self) -> None:
        """Release backend resources; the base implementation does nothing."""
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# ============================================================
|
|
141
|
+
# api_openai Provider
|
|
142
|
+
# ============================================================
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI-compatible embedding API provider.

    Posts ``{"model": ..., "input": [...]}`` to ``api_url`` with a bearer token
    and L2-normalizes the returned vectors so a dot product equals cosine
    similarity.
    """

    def __init__(self, api_key: str, api_url: str, model: str, timeout: int):
        self._api_key = api_key
        self._api_url = api_url
        self._model = model or "text-embedding-3-small"
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None
        # Initial value from the environment; replaced with the actual vector
        # width after the first non-empty response.
        self._dimensions = int(os.environ.get("EMBEDDING_DIMENSIONS", "1536"))

    async def initialize(self) -> None:
        """Create the HTTP client.

        Raises:
            ValueError: if no API key was configured.
        """
        if not self._api_key:
            raise ValueError("EMBEDDING_API_KEY is required for api_openai provider")
        self._client = httpx.AsyncClient(timeout=self._timeout)
        logger.info(
            "OpenAI embedding provider initialized (model=%s, url=%s)",
            self._model,
            self._api_url,
        )

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` via the remote API.

        Returns:
            One L2-normalized vector per input, aligned with ``texts``.
            An empty ``texts`` list returns ``[]`` without any HTTP request.

        Raises:
            RuntimeError: if called before ``initialize()``, or if the API
                returns a different number of vectors than inputs.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        if not self._client:
            raise RuntimeError("Provider not initialized")
        if not texts:
            # Avoid sending a request with an empty "input" list.
            return []

        response = await self._client.post(
            self._api_url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            json={"model": self._model, "input": texts},
        )
        response.raise_for_status()

        data = response.json()["data"]
        # Each item carries the position of its input ("index"); sort on it so
        # results line up with ``texts``. sorted() is stable, so items without
        # an index keep the server's order.
        items = sorted(data, key=lambda item: item.get("index", 0))
        embeddings = [item["embedding"] for item in items]
        if len(embeddings) != len(texts):
            raise RuntimeError(
                f"Embedding API returned {len(embeddings)} vectors for {len(texts)} inputs"
            )

        # Update dimensions from actual response
        if embeddings:
            self._dimensions = len(embeddings[0])

        # L2-normalize for consistent cosine similarity via dot product
        return [self._l2_normalize(emb) for emb in embeddings]

    @staticmethod
    def _l2_normalize(emb) -> list[float]:
        """Return ``emb`` scaled to unit length (near-zero vectors are left as-is)."""
        vec = np.array(emb, dtype=np.float32)
        norm = np.linalg.norm(vec)
        if norm > 1e-9:
            vec = vec / norm
        return vec.tolist()

    def dimensions(self) -> int:
        """Vector width: the env-configured value until a response has been seen."""
        return self._dimensions

    async def shutdown(self) -> None:
        """Close the HTTP client if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# ============================================================
|
|
208
|
+
# onnx_miniml Provider
|
|
209
|
+
# ============================================================
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class OnnxMiniLMProvider(EmbeddingProvider):
    """Local all-MiniLM-L6-v2 ONNX embedding provider.

    Mean pooling over non-padding tokens followed by L2 normalization; 384-dim
    output. The tokenizer sequence length is ONNX_MAX_SEQ_LEN clamped to 512.
    """

    def __init__(self, model_dir: str):
        self._model_dir = model_dir
        self._session = None
        self._tokenizer = None
        # Serializes inference so only one executor job runs the session at a time.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Load (downloading if missing) the ONNX model and tokenizer.

        Raises:
            ImportError: if onnxruntime or tokenizers is not installed.
            FileNotFoundError: if the model is missing and cannot be downloaded.
        """
        try:
            import onnxruntime  # noqa: F401  # availability check — session is built by _create_ort_session
            from tokenizers import Tokenizer
        except ImportError as err:
            raise ImportError(
                "onnx_miniml provider requires: uv pip install onnxruntime tokenizers\n"
                "Or: uv pip install cloto-mcp-embedding[onnx]"
            ) from err

        model_path = os.path.join(self._model_dir, "model.onnx")
        tokenizer_path = os.path.join(self._model_dir, "tokenizer.json")

        # Auto-download model if missing
        if not os.path.exists(model_path) or not os.path.exists(tokenizer_path):
            logger.info("ONNX model not found, downloading automatically...")
            try:
                from cembedding.download_model import download

                # NOTE(review): unlike the Jina/BGE downloaders, download() is
                # called without a target directory; if ONNX_MODEL_DIR points
                # elsewhere the files may not land in self._model_dir — verify
                # against download_model.py.
                if not download():
                    raise FileNotFoundError(f"Failed to download ONNX model to {self._model_dir}")
            except ImportError as err:
                raise FileNotFoundError(
                    f"ONNX model not found at {model_path}. "
                    f"Download with: python -m cembedding.download_model"
                ) from err

        providers = _select_ort_providers()

        # Shared factory used by the other ONNX providers: retries without
        # CoreML options / on CPU when the accelerated session cannot be created.
        self._session = _create_ort_session(model_path, providers)
        self._tokenizer = Tokenizer.from_file(tokenizer_path)
        # MiniLM positional embeddings cap at 512 — clamp ONNX_MAX_SEQ_LEN.
        miniml_seq_len = min(ONNX_MAX_SEQ_LEN, 512)
        if miniml_seq_len < ONNX_MAX_SEQ_LEN:
            logger.warning(
                "MiniLM max_position=512, clamping ONNX_MAX_SEQ_LEN=%d to 512",
                ONNX_MAX_SEQ_LEN,
            )
        self._tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=miniml_seq_len)
        self._tokenizer.enable_truncation(max_length=miniml_seq_len)

        logger.info(
            "ONNX MiniLM provider initialized (dir=%s, seq_len=%d, requested=%s, active=%s)",
            self._model_dir,
            miniml_seq_len,
            providers,
            self._session.get_providers(),
        )

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in a worker thread; returns ``[]`` for an empty batch.

        Raises:
            RuntimeError: if called before ``initialize()``.
        """
        if not self._session or not self._tokenizer:
            raise RuntimeError("Provider not initialized")
        if not texts:
            # An empty batch would produce a rank-1 input array the model rejects.
            return []

        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Synchronous embedding (run in executor to avoid blocking)."""
        encodings = self._tokenizer.encode_batch(texts)

        input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)
        token_type_ids = np.array([e.type_ids for e in encodings], dtype=np.int64)

        outputs = self._session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            },
        )
        token_embeddings = outputs[0]  # (batch, seq_len, hidden_dim)

        # Mean pooling over real (mask == 1) tokens + L2 normalization
        mask_expanded = np.expand_dims(attention_mask, -1).astype(np.float32)
        sum_embeddings = np.sum(token_embeddings * mask_expanded, axis=1)
        sum_mask = np.clip(np.sum(mask_expanded, axis=1), a_min=1e-9, a_max=None)
        mean_pooled = sum_embeddings / sum_mask

        norms = np.linalg.norm(mean_pooled, axis=1, keepdims=True)
        norms = np.clip(norms, a_min=1e-9, a_max=None)
        normalized = mean_pooled / norms

        return normalized.tolist()

    def dimensions(self) -> int:
        """all-MiniLM-L6-v2 output width."""
        return 384

    async def shutdown(self) -> None:
        """Drop the session and tokenizer references."""
        self._session = None
        self._tokenizer = None
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ============================================================
|
|
316
|
+
# Platform detection helpers
|
|
317
|
+
# ============================================================
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _is_apple_silicon() -> bool:
|
|
321
|
+
return sys.platform == "darwin" and _platform.machine() == "arm64"
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _mlx_available() -> bool:
|
|
325
|
+
try:
|
|
326
|
+
import mlx.core # noqa: F401
|
|
327
|
+
import mlx_embeddings # noqa: F401
|
|
328
|
+
|
|
329
|
+
return True
|
|
330
|
+
except ImportError:
|
|
331
|
+
return False
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
# ============================================================
|
|
335
|
+
# CoreML-aware session factory (shared by ONNX providers)
|
|
336
|
+
# ============================================================
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def _create_ort_session(model_path: str, providers: list):
    """Create an ONNX Runtime InferenceSession, degrading to CPU on failure.

    With CoreMLExecutionProvider in ``providers``:
      Stage 1: CoreML with MLProgram + dynamic shapes (ort 1.18+).
      Stage 2: CoreML without provider_options (ort version mismatch tolerance).
      Stage 3: CPU only (final fallback).

    Without CoreML, ``providers`` is tried as-is; if that fails and anything
    other than the CPU provider was requested, a CPU-only session is attempted.

    Raises:
        Whatever ``ort.InferenceSession`` raises for the final CPU-only attempt
        (e.g. a missing or invalid model file).
    """
    import onnxruntime as ort

    cpu_only = ["CPUExecutionProvider"]

    if "CoreMLExecutionProvider" in providers:
        rest = [p for p in providers if p != "CoreMLExecutionProvider"]
        providers_with_opts = [
            (
                "CoreMLExecutionProvider",
                {
                    "ModelFormat": "MLProgram",
                    "MLComputeUnits": "ALL",
                    "RequireStaticInputShapes": "0",
                },
            ),
            *rest,
        ]
        try:
            return ort.InferenceSession(model_path, providers=providers_with_opts)
        except Exception as e:
            logger.warning("CoreML init with provider_options failed, retrying without options: %s", e)
        try:
            return ort.InferenceSession(model_path, providers=providers)
        except Exception as e:
            logger.warning("CoreML plain init failed, falling back to CPU-only: %s", e)
        return ort.InferenceSession(model_path, providers=cpu_only)

    try:
        return ort.InferenceSession(model_path, providers=providers)
    except Exception as e:
        # Retrying with CPU only cannot help if CPU was all we asked for.
        if list(providers) == cpu_only:
            raise
        logger.warning("ONNX session init with %s failed, falling back to CPU-only: %s", providers, e)
        return ort.InferenceSession(model_path, providers=cpu_only)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
# ============================================================
|
|
375
|
+
# mlx_bge_m3 Provider (macOS Apple Silicon only)
|
|
376
|
+
# ============================================================
|
|
377
|
+
|
|
378
|
+
# Default model location for MlxBgeM3Provider: a HuggingFace repo ID holding the
# fp16 MLX weights (the provider also accepts an absolute local directory).
MLX_BGE_M3_REPO: str = "mlx-community/bge-m3-mlx-fp16"
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
class MlxBgeM3Provider(EmbeddingProvider):
    """Apple Silicon MLX bge-m3 provider (fp16, Metal/ANE accelerated).

    Uses mlx-embeddings to run BAAI/bge-m3 natively on Metal GPU.
    macOS ARM only — auto_bge_m3 falls back to OnnxBgeM3Provider on other platforms.
    """

    def __init__(self, model_path: str = MLX_BGE_M3_REPO):
        # model_path: HuggingFace repo ID or absolute local directory
        self._model_path = model_path
        self._model = None
        self._tokenizer = None
        # Serializes inference so only one executor job uses the model at a time.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Load model and tokenizer with ``mlx_embeddings.utils.load`` in a worker thread.

        Raises:
            ImportError: if mlx-embeddings is not installed.
        """
        try:
            from mlx_embeddings.utils import load as mlx_load
        except ImportError as err:
            raise ImportError("mlx_bge_m3 requires: uv pip install mlx-embeddings") from err

        loop = asyncio.get_running_loop()
        self._model, self._tokenizer = await loop.run_in_executor(None, mlx_load, self._model_path)
        logger.info("MLX BGE-M3 provider initialized (path=%s)", self._model_path)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in a worker thread; returns ``[]`` for an empty batch.

        Raises:
            RuntimeError: if called before ``initialize()``.
        """
        if not self._model or not self._tokenizer:
            raise RuntimeError("Provider not initialized")
        if not texts:
            # Nothing to tokenize; skip the model call entirely.
            return []

        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Tokenize, run the MLX model and return ``text_embeds`` as Python lists."""
        import mlx.core as mx

        inputs = self._tokenizer.batch_encode_plus(
            texts,
            return_tensors="mlx",
            padding=True,
            truncation=True,
            max_length=min(ONNX_MAX_SEQ_LEN, 8192),
        )
        outputs = self._model(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
        )
        embs = outputs.text_embeds  # already L2-normalized
        # MLX is lazy; force evaluation before converting to Python lists.
        mx.eval(embs)
        return embs.tolist()

    def dimensions(self) -> int:
        """bge-m3 dense output width."""
        return 1024

    async def shutdown(self) -> None:
        """Drop the model and tokenizer references."""
        self._model = None
        self._tokenizer = None
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
# ============================================================
|
|
438
|
+
# onnx_bge_m3 Provider
|
|
439
|
+
# ============================================================
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class OnnxBgeM3Provider(EmbeddingProvider):
    """Local BAAI/bge-m3 ONNX embedding provider (int8 quantized).

    CLS-token pooling. 1024-dim dense output, 8K context, 100+ languages.
    Based on XLM-RoBERTa — no token_type_ids.
    """

    def __init__(self, model_dir: str):
        self._model_dir = model_dir
        self._session = None
        self._tokenizer = None
        # Serializes inference so only one executor job runs the session at a time.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Load (downloading if missing) the ONNX model and tokenizer.

        Raises:
            ImportError: if onnxruntime or tokenizers is not installed.
            FileNotFoundError: if the model is missing and cannot be downloaded.
        """
        try:
            import onnxruntime  # noqa: F401
            from tokenizers import Tokenizer
        except ImportError as err:
            raise ImportError("onnx_bge_m3 provider requires: uv pip install onnxruntime tokenizers") from err

        model_dir_abs = os.path.abspath(self._model_dir)
        model_path = os.path.join(model_dir_abs, "model.onnx")
        tokenizer_path = os.path.join(model_dir_abs, "tokenizer.json")

        if not os.path.exists(model_path) or not os.path.exists(tokenizer_path):
            logger.info("BGE-M3 ONNX model not found, downloading...")
            try:
                from cembedding.download_model import download_bge_m3

                if not download_bge_m3(model_dir_abs):
                    raise FileNotFoundError(f"Failed to download BGE-M3 model to {model_dir_abs}")
            except ImportError as err:
                raise FileNotFoundError(
                    f"ONNX model not found at {model_path}. "
                    f"Download with: python -m cembedding.download_model --model bge-m3"
                ) from err

        providers = _select_ort_providers()
        self._session = _create_ort_session(model_path, providers)
        self._tokenizer = Tokenizer.from_file(tokenizer_path)
        bge_seq_len = min(ONNX_MAX_SEQ_LEN, 8192)
        # XLM-RoBERTa pad token is <pad> (id=1)
        self._tokenizer.enable_padding(pad_id=1, pad_token="<pad>", length=bge_seq_len)
        self._tokenizer.enable_truncation(max_length=bge_seq_len)

        logger.info(
            "ONNX BGE-M3 provider initialized (dir=%s, seq_len=%d, requested=%s, active=%s)",
            self._model_dir,
            bge_seq_len,
            providers,
            self._session.get_providers(),
        )

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in a worker thread; returns ``[]`` for an empty batch.

        Raises:
            RuntimeError: if called before ``initialize()``.
        """
        if not self._session or not self._tokenizer:
            raise RuntimeError("Provider not initialized")
        if not texts:
            # An empty batch would produce a rank-1 input array the model rejects.
            return []

        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Synchronous embedding with CLS-token pooling."""
        encodings = self._tokenizer.encode_batch(texts)

        input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)

        # Only feed token_type_ids if this particular export declares that input.
        input_names = [inp.name for inp in self._session.get_inputs()]
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if "token_type_ids" in input_names:
            inputs["token_type_ids"] = np.zeros_like(input_ids)

        outputs = self._session.run(None, inputs)
        token_embeddings = outputs[0]  # (batch, seq_len, 1024)

        # CLS-token pooling: first token of each sequence
        cls_embs = token_embeddings[:, 0, :].astype(np.float32, copy=False)

        # L2 normalization
        norms = np.linalg.norm(cls_embs, axis=1, keepdims=True)
        norms = np.clip(norms, a_min=1e-9, a_max=None)
        normalized = cls_embs / norms

        return normalized.tolist()

    def dimensions(self) -> int:
        """bge-m3 dense output width."""
        return 1024

    async def shutdown(self) -> None:
        """Drop the session and tokenizer references."""
        self._session = None
        self._tokenizer = None
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
# ============================================================
|
|
536
|
+
# onnx_jina_v5_nano Provider
|
|
537
|
+
# ============================================================
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
class OnnxJinaV5NanoProvider(EmbeddingProvider):
    """Local jina-embeddings-v5-text-nano ONNX embedding provider.

    Uses Last-Token pooling (different from MiniLM's mean pooling).
    768-dim output, 8K context, retrieval-optimized merged LoRA.
    """

    def __init__(self, model_dir: str):
        self._model_dir = model_dir
        self._session = None
        self._tokenizer = None
        # Serializes inference so only one executor job runs the session at a time.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Load (downloading if missing) the ONNX model and tokenizer.

        Raises:
            ImportError: if onnxruntime or tokenizers is not installed.
            FileNotFoundError: if the model is missing and cannot be downloaded.
        """
        try:
            import onnxruntime  # noqa: F401  # availability check — session is created in _create_session_with_fallback
            from tokenizers import Tokenizer
        except ImportError as err:
            raise ImportError(
                "onnx_jina_v5_nano provider requires: uv pip install onnxruntime tokenizers\n"
                "Or: uv pip install cloto-mcp-embedding[onnx]"
            ) from err

        model_path = os.path.join(self._model_dir, "model.onnx")
        tokenizer_path = os.path.join(self._model_dir, "tokenizer.json")

        # Auto-download model if missing
        if not os.path.exists(model_path) or not os.path.exists(tokenizer_path):
            logger.info("Jina-v5-nano ONNX model not found, downloading...")
            try:
                from cembedding.download_model import download_jina_v5_nano

                if not download_jina_v5_nano(self._model_dir):
                    raise FileNotFoundError(f"Failed to download model to {self._model_dir}")
            except ImportError as err:
                raise FileNotFoundError(
                    f"ONNX model not found at {model_path}. "
                    f"Download with: python -m cembedding.download_model --model jina-v5-nano"
                ) from err

        providers = _select_ort_providers()
        self._session = self._create_session_with_fallback(model_path, providers)
        self._tokenizer = Tokenizer.from_file(tokenizer_path)
        # jina-v5-nano supports up to 8K context via RoPE.
        self._tokenizer.enable_padding(pad_id=0, pad_token="<pad>", length=ONNX_MAX_SEQ_LEN)
        self._tokenizer.enable_truncation(max_length=ONNX_MAX_SEQ_LEN)

        logger.info(
            "ONNX Jina-v5-nano provider initialized (dir=%s, seq_len=%d, requested=%s, active=%s)",
            self._model_dir,
            ONNX_MAX_SEQ_LEN,
            providers,
            self._session.get_providers(),
        )

    def _create_session_with_fallback(self, model_path: str, providers: list):
        """Delegate to the shared CoreML-aware session factory."""
        return _create_ort_session(model_path, providers)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in a worker thread; returns ``[]`` for an empty batch.

        Raises:
            RuntimeError: if called before ``initialize()``.
        """
        if not self._session or not self._tokenizer:
            raise RuntimeError("Provider not initialized")
        if not texts:
            # An empty batch would produce a rank-1 input array the model rejects.
            return []

        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Synchronous embedding with Last-Token pooling."""
        encodings = self._tokenizer.encode_batch(texts)

        input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)

        # jina-v5-nano may not use token_type_ids — check model inputs
        input_names = [inp.name for inp in self._session.get_inputs()]
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if "token_type_ids" in input_names:
            inputs["token_type_ids"] = np.zeros_like(input_ids)

        outputs = self._session.run(None, inputs)
        token_embeddings = outputs[0]  # (batch, seq_len, hidden_dim=768)

        # Last-Token pooling (vectorized): gather embedding at the last non-padding
        # token per row. Index = (number of mask==1 tokens) - 1, which assumes the
        # tokenizer pads on the right.
        last_indices = np.maximum(np.sum(attention_mask, axis=1) - 1, 0).astype(np.int64)
        batch_indices = np.arange(token_embeddings.shape[0])
        last_token_embs = token_embeddings[batch_indices, last_indices].astype(np.float32, copy=False)

        # L2 normalization
        norms = np.linalg.norm(last_token_embs, axis=1, keepdims=True)
        norms = np.clip(norms, a_min=1e-9, a_max=None)
        normalized = last_token_embs / norms

        return normalized.tolist()

    def dimensions(self) -> int:
        """jina-embeddings-v5-text-nano output width."""
        return 768

    async def shutdown(self) -> None:
        """Drop the session and tokenizer references."""
        self._session = None
        self._tokenizer = None
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
# ============================================================
|
|
642
|
+
# Vector Index (v0.2.0)
|
|
643
|
+
# ============================================================
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
class VectorIndex:
|
|
647
|
+
"""Persistent vector index with in-memory search.
|
|
648
|
+
|
|
649
|
+
Stores vectors in SQLite for durability, loads into memory for fast
|
|
650
|
+
brute-force dot product search. Namespaced to support multiple consumers.
|
|
651
|
+
"""
|
|
652
|
+
|
|
653
|
+
def __init__(self, db_path: str):
    """Record where the SQLite store lives; nothing is opened until initialize()."""
    # In-memory mirror of the table, keyed {namespace: {item_id: float32 vector}}.
    self._index: dict[str, dict[str, np.ndarray]] = {}
    # aiosqlite connection; set by initialize().
    self._db = None
    self._db_path = db_path
|
|
658
|
+
|
|
659
|
+
async def initialize(self) -> None:
    """Open (creating if needed) the SQLite store and load every vector into memory.

    Creates the parent directory of the DB path, enables WAL journaling, and
    ensures the ``vectors`` table and namespace index exist. Stored blobs are
    decoded as little-endian float32. A blob whose length is not a multiple of
    4 bytes is skipped with a warning rather than aborting startup.
    """
    import aiosqlite

    db_dir = os.path.dirname(self._db_path)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    self._db = await aiosqlite.connect(self._db_path)
    await self._db.execute("PRAGMA journal_mode=WAL")
    await self._db.execute("PRAGMA synchronous=NORMAL")
    await self._db.executescript(
        """
        CREATE TABLE IF NOT EXISTS vectors (
            namespace TEXT NOT NULL,
            item_id TEXT NOT NULL,
            vector BLOB NOT NULL,
            created_at TEXT DEFAULT (datetime('now')),
            PRIMARY KEY (namespace, item_id)
        );
        CREATE INDEX IF NOT EXISTS idx_vectors_ns ON vectors (namespace);
        """
    )
    await self._db.commit()

    # Load all vectors into memory
    rows = await self._db.execute_fetchall("SELECT namespace, item_id, vector FROM vectors")
    skipped = 0
    for ns, item_id, blob in rows:
        if len(blob) % 4:
            skipped += 1
            continue
        # Decode with an explicit little-endian dtype so the on-disk format does
        # not depend on host byte order; astype() yields a writable native copy.
        self._index.setdefault(ns, {})[item_id] = np.frombuffer(blob, dtype="<f4").astype(np.float32)

    if skipped:
        logger.warning("VectorIndex skipped %d malformed vector blob(s)", skipped)

    total = sum(len(v) for v in self._index.values())
    logger.info("VectorIndex loaded: %d vectors across %d namespaces", total, len(self._index))
|
|
692
|
+
|
|
693
|
+
async def index(self, namespace: str, items: list[dict], provider: "EmbeddingProvider") -> int:
    """Embed and upsert items into ``namespace``.

    Each item is a dict with ``"id"`` and ``"text"`` keys; an existing entry
    with the same id is replaced. All rows are written and committed before
    the in-memory index is updated, so a failed write does not leave memory
    ahead of disk.

    Returns:
        Number of items indexed. An empty ``items`` list returns 0 without
        calling the provider.

    Raises:
        RuntimeError: if not initialized, or if the provider returns a
            different number of vectors than there are items.
    """
    if not self._db:
        raise RuntimeError("VectorIndex not initialized")
    if not items:
        return 0

    texts = [item["text"] for item in items]
    embeddings = await provider.embed(texts)
    if len(embeddings) != len(items):
        raise RuntimeError(
            f"Embedding provider returned {len(embeddings)} vectors for {len(items)} items"
        )

    vectors = [np.array(emb, dtype=np.float32) for emb in embeddings]
    # Explicit little-endian float32 blobs keep the on-disk format host-independent.
    rows = [(namespace, item["id"], vec.astype("<f4").tobytes()) for item, vec in zip(items, vectors)]
    await self._db.executemany(
        "INSERT OR REPLACE INTO vectors (namespace, item_id, vector) VALUES (?, ?, ?)",
        rows,
    )
    await self._db.commit()

    ns_index = self._index.setdefault(namespace, {})
    for item, vec in zip(items, vectors):
        ns_index[item["id"]] = vec
    return len(rows)
|
|
719
|
+
|
|
720
|
+
async def search(
    self,
    namespace: str,
    query: str,
    limit: int,
    min_similarity: float,
    provider: "EmbeddingProvider",
) -> list[dict]:
    """Rank the vectors of ``namespace`` by cosine similarity to ``query``.

    The query is embedded with ``provider``. Stored vectors whose dimension
    differs from the query's, or whose norm is zero, are skipped. Scores are
    true cosine similarities, so ``min_similarity`` behaves as documented even
    when the provider does not return unit-normalized vectors.

    Returns:
        Up to ``limit`` dicts ``{"id", "score"}`` (score rounded to 4 places),
        sorted by score descending; empty for unknown namespaces or ``limit <= 0``.
    """
    import heapq

    ns_index = self._index.get(namespace)
    if not ns_index or limit <= 0:
        return []

    embeddings = await provider.embed([query])
    # len() rather than truthiness: an ndarray embedding has no unambiguous bool.
    if embeddings is None or len(embeddings) == 0 or len(embeddings[0]) == 0:
        return []

    query_vec = np.array(embeddings[0], dtype=np.float32)
    query_norm = float(np.linalg.norm(query_vec))
    if query_norm == 0.0:
        return []
    query_dim = len(query_vec)

    candidates = []
    for item_id, vec in ns_index.items():
        if len(vec) != query_dim:
            continue
        vec_norm = float(np.linalg.norm(vec))
        if vec_norm == 0.0:
            continue
        sim = float(np.dot(query_vec, vec)) / (query_norm * vec_norm)
        if sim >= min_similarity:
            candidates.append((sim, item_id))

    top_k = heapq.nlargest(limit, candidates, key=lambda pair: pair[0])
    return [{"id": item_id, "score": round(score, 4)} for score, item_id in top_k]
|
|
753
|
+
|
|
754
|
+
async def remove(self, namespace: str, ids: list[str]) -> int:
    """Delete ``ids`` from ``namespace`` in SQLite and in memory.

    An id counts as removed only when SQLite reports a deleted row; only those
    ids are dropped from the in-memory map. All deletes share one commit.

    Raises:
        RuntimeError: if the database connection has not been opened.
    """
    if not self._db:
        raise RuntimeError("VectorIndex not initialized")

    deleted = 0
    in_memory = self._index.get(namespace, {})
    for target in ids:
        cur = await self._db.execute(
            "DELETE FROM vectors WHERE namespace = ? AND item_id = ?",
            (namespace, target),
        )
        if cur.rowcount <= 0:
            continue
        deleted += 1
        in_memory.pop(target, None)

    await self._db.commit()
    return deleted
|
|
772
|
+
|
|
773
|
+
async def purge_namespace(self, namespace: str) -> int:
    """Drop every vector stored under ``namespace``.

    Deletes the rows in SQLite, commits, then discards the namespace's
    in-memory map. Returns the larger of the DB row count and the number of
    in-memory entries dropped.

    Raises:
        RuntimeError: if the database connection has not been opened.
    """
    if not self._db:
        raise RuntimeError("VectorIndex not initialized")

    cur = await self._db.execute("DELETE FROM vectors WHERE namespace = ?", (namespace,))
    await self._db.commit()
    dropped = self._index.pop(namespace, {})
    return max(cur.rowcount, len(dropped))
|
|
782
|
+
|
|
783
|
+
async def count(self, namespace: str) -> int:
    """Return how many vectors are held in memory for ``namespace`` (0 if unknown)."""
    bucket = self._index.get(namespace)
    return 0 if bucket is None else len(bucket)
|
|
786
|
+
|
|
787
|
+
async def shutdown(self) -> None:
    """Close the SQLite connection if one is open and drop all cached vectors."""
    connection = self._db
    if connection:
        await connection.close()
        self._db = None
    self._index.clear()
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
# Process-wide vector index. main() / _run_streamable_http() assign it only when
# EMBEDDING_INDEX_ENABLED; while it is None the index/search/remove/purge
# handlers and tools answer with a "not initialized" error.
_vector_index: VectorIndex | None = None
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
# ============================================================
|
|
798
|
+
# Provider Factory
|
|
799
|
+
# ============================================================
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
def _resolve_model_dir(provider_key: str) -> str:
    """Return the model directory configured for ``provider_key``.

    Unknown keys fall back to ``data/models/bge-m3``. A relative directory is
    anchored under ``_project_dir`` (CLOTO_PROJECT_DIR) when that is set;
    absolute directories are returned unchanged.
    """
    model_dir = _MODEL_DIRS.get(provider_key, "data/models/bge-m3")
    anchor_to_project = bool(_project_dir) and not os.path.isabs(model_dir)
    return os.path.join(_project_dir, model_dir) if anchor_to_project else model_dir
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
def _mlx_bge_m3_model_path() -> str:
    """Choose where the MLX BGE-M3 model is loaded from.

    Precedence: the ``MLX_MODEL_DIR`` environment variable, then the local
    directory from ``_resolve_model_dir("mlx_bge_m3")`` if it exists, then the
    ``MLX_BGE_M3_REPO`` identifier.
    """
    explicit = os.environ.get("MLX_MODEL_DIR", "")
    if explicit:
        return explicit
    local = _resolve_model_dir("mlx_bge_m3")
    return local if os.path.isdir(local) else MLX_BGE_M3_REPO


def create_provider() -> EmbeddingProvider:
    """Create an embedding provider based on configuration.

    Selects the implementation from ``EMBEDDING_PROVIDER``. ``auto_bge_m3`` uses
    MLX on Apple Silicon when MLX is importable and otherwise ONNX on CPU. Both
    MLX paths share the same model-location resolution, so a locally downloaded
    MLX model is honored in auto mode too.

    Raises:
        ValueError: for an unknown ``EMBEDDING_PROVIDER`` value.
    """
    if EMBEDDING_PROVIDER == "api_openai":
        return OpenAIEmbeddingProvider(
            api_key=EMBEDDING_API_KEY,
            api_url=EMBEDDING_API_URL,
            model=EMBEDDING_MODEL,
            timeout=EMBEDDING_TIMEOUT,
        )
    elif EMBEDDING_PROVIDER == "onnx_miniml":
        return OnnxMiniLMProvider(model_dir=ONNX_MODEL_DIR)
    elif EMBEDDING_PROVIDER == "onnx_jina_v5_nano":
        return OnnxJinaV5NanoProvider(model_dir=ONNX_MODEL_DIR)
    elif EMBEDDING_PROVIDER == "onnx_bge_m3":
        return OnnxBgeM3Provider(model_dir=ONNX_MODEL_DIR)
    elif EMBEDDING_PROVIDER == "mlx_bge_m3":
        return MlxBgeM3Provider(model_path=_mlx_bge_m3_model_path())
    elif EMBEDDING_PROVIDER == "auto_bge_m3":
        if _is_apple_silicon() and _mlx_available():
            mlx_path = _mlx_bge_m3_model_path()
            logger.info("auto_bge_m3: Apple Silicon detected, using MLX provider")
            return MlxBgeM3Provider(model_path=mlx_path)
        logger.info("auto_bge_m3: falling back to ONNX CPU provider")
        return OnnxBgeM3Provider(model_dir=ONNX_MODEL_DIR)
    else:
        raise ValueError(
            f"Unknown embedding provider: {EMBEDDING_PROVIDER}. "
            "Supported: api_openai, onnx_miniml, onnx_jina_v5_nano, onnx_bge_m3, "
            "mlx_bge_m3, auto_bge_m3"
        )
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
# ============================================================
|
|
847
|
+
# HTTP Endpoint (for CPersona inter-server communication)
|
|
848
|
+
# ============================================================
|
|
849
|
+
|
|
850
|
+
# Active embedding provider, created via create_provider() and initialized by
# main() / _run_streamable_http(). The HTTP handlers return 503 and the MCP
# tools return an error while it is still None.
_provider: EmbeddingProvider | None = None
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
async def handle_embed(request: web.Request) -> web.Response:
    """POST /embed — Generate embeddings for input texts.

    Body: ``{"texts": [str, ...]}`` with 1-100 strings.
    Replies ``{"embeddings": [...], "dimensions": int}``; 400 for malformed
    input, 503 while the provider is not initialized, 500 if embedding fails.
    """
    if _provider is None:
        return web.json_response({"error": "Provider not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    # Valid JSON that is not an object (e.g. a list) would crash on body.get().
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    texts = body.get("texts")
    if not isinstance(texts, list) or not texts or not all(isinstance(t, str) for t in texts):
        return web.json_response(
            {"error": "'texts' must be a non-empty array of strings"},
            status=400,
        )

    # Limit batch size to prevent OOM
    if len(texts) > 100:
        return web.json_response({"error": "Batch size exceeds limit (max 100)"}, status=400)

    try:
        embeddings = await _provider.embed(texts)
        return web.json_response(
            {
                "embeddings": embeddings,
                "dimensions": _provider.dimensions(),
            }
        )
    except Exception as e:
        logger.exception("Embedding failed")
        return web.json_response({"error": f"Embedding failed: {e}"}, status=500)
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
async def handle_index(request: web.Request) -> web.Response:
    """POST /index — Index vectors for later search.

    Body: ``{"namespace": str = "default", "items": [{"id", "text"}, ...]}``
    with 1-100 items; ``id`` must be a string or integer and ``text`` a string.
    Replies ``{"ok": true, "indexed": n}``; 400 for malformed input, 503 while
    the provider or index is missing, 500 if indexing fails.
    """
    if _provider is None or _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    # Valid JSON that is not an object (e.g. a list) would crash on body.get().
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace", "default")
    if not isinstance(namespace, str):
        return web.json_response({"error": "'namespace' must be a string"}, status=400)

    items = body.get("items")
    if not isinstance(items, list) or not items:
        return web.json_response({"error": "'items' must be a non-empty array"}, status=400)
    if len(items) > 100:
        return web.json_response({"error": "Batch size exceeds limit (max 100)"}, status=400)

    for item in items:
        if not isinstance(item, dict) or "id" not in item or "text" not in item:
            return web.json_response({"error": "Each item must have 'id' and 'text'"}, status=400)
        # Reject values SQLite cannot bind or the provider cannot embed here,
        # instead of surfacing them later as a 500.
        item_id = item["id"]
        if (
            not isinstance(item["text"], str)
            or isinstance(item_id, bool)
            or not isinstance(item_id, (str, int))
        ):
            return web.json_response(
                {"error": "Each item needs a string or integer 'id' and a string 'text'"},
                status=400,
            )

    try:
        indexed = await _vector_index.index(namespace, items, _provider)
        return web.json_response({"ok": True, "indexed": indexed})
    except Exception as e:
        logger.exception("Index failed")
        return web.json_response({"error": f"Index failed: {e}"}, status=500)
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
async def handle_search(request: web.Request) -> web.Response:
    """POST /search — Search indexed vectors by similarity.

    Body: ``{"namespace": str = "default", "query": str, "limit": int = 10,
    "min_similarity": float = 0.3}``; ``limit`` is capped at 500.
    Replies ``{"results": [{"id", "score"}, ...]}``; 400 for malformed input,
    503 while the provider or index is missing, 500 if the search fails.
    """
    if _provider is None or _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    # Valid JSON that is not an object (e.g. a list) would crash on body.get().
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace", "default")
    if not isinstance(namespace, str):
        return web.json_response({"error": "'namespace' must be a string"}, status=400)

    query = body.get("query")
    if not isinstance(query, str) or not query.strip():
        return web.json_response({"error": "'query' must be a non-empty string"}, status=400)

    # Bad client values are a 400, not an unhandled exception.
    try:
        limit = min(int(body.get("limit", 10)), 500)
        min_similarity = float(body.get("min_similarity", 0.3))
    except (TypeError, ValueError, OverflowError):
        return web.json_response(
            {"error": "'limit' must be an integer and 'min_similarity' a number"},
            status=400,
        )

    try:
        results = await _vector_index.search(namespace, query, limit, min_similarity, _provider)
        return web.json_response({"results": results})
    except Exception as e:
        logger.exception("Search failed")
        return web.json_response({"error": f"Search failed: {e}"}, status=500)
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
async def handle_remove(request: web.Request) -> web.Response:
    """POST /remove — Remove vectors from index.

    Body: ``{"namespace": str = "default", "ids": [...]}`` with at least one id.
    Replies ``{"ok": true, "removed": n}``; 400 for malformed input, 503 while
    the index is missing, 500 if removal fails.
    """
    if _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    # Valid JSON that is not an object (e.g. a list) would crash on body.get().
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace", "default")
    if not isinstance(namespace, str):
        return web.json_response({"error": "'namespace' must be a string"}, status=400)

    ids = body.get("ids")
    if not isinstance(ids, list) or not ids:
        return web.json_response({"error": "'ids' must be a non-empty array"}, status=400)

    try:
        removed = await _vector_index.remove(namespace, ids)
        return web.json_response({"ok": True, "removed": removed})
    except Exception as e:
        logger.exception("Remove failed")
        return web.json_response({"error": f"Remove failed: {e}"}, status=500)
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
async def handle_purge(request: web.Request) -> web.Response:
    """POST /purge — Remove all vectors in a namespace.

    Body: ``{"namespace": str}`` (required, not blank).
    Replies ``{"ok": true, "removed": n}``; 400 for malformed input, 503 while
    the index is missing, 500 if the purge fails.
    """
    if _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)
    # Valid JSON that is not an object (e.g. a list) would crash on body.get().
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace")
    if not isinstance(namespace, str) or not namespace.strip():
        return web.json_response({"error": "'namespace' must be a non-empty string"}, status=400)

    try:
        removed = await _vector_index.purge_namespace(namespace)
        return web.json_response({"ok": True, "removed": removed})
    except Exception as e:
        logger.exception("Purge failed")
        return web.json_response({"error": f"Purge failed: {e}"}, status=500)
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
async def run_http_server(port: int) -> None:
    """Run the HTTP embedding endpoint alongside MCP stdio.

    Binds 127.0.0.1:``port``. ``/embed`` is always registered; ``/index``,
    ``/search``, ``/remove`` and ``/purge`` are registered only when
    EMBEDDING_INDEX_ENABLED and the vector index exists. Blocks until the task
    is cancelled, then cleans up the runner.
    """
    app = web.Application()
    app.router.add_post("/embed", handle_embed)
    if EMBEDDING_INDEX_ENABLED and _vector_index is not None:
        app.router.add_post("/index", handle_index)
        app.router.add_post("/search", handle_search)
        app.router.add_post("/remove", handle_remove)
        app.router.add_post("/purge", handle_purge)

    runner = web.AppRunner(app, access_log=None)
    await runner.setup()
    # The try starts before binding so the runner is cleaned up even when
    # site.start() fails (e.g. the port is already in use).
    try:
        site = web.TCPSite(runner, "127.0.0.1", port)
        await site.start()
        logger.info("HTTP embedding endpoint started on http://127.0.0.1:%d/embed", port)

        # Block until cancelled
        await asyncio.Event().wait()
    finally:
        await runner.cleanup()
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
# ============================================================
|
|
1011
|
+
# MCP Server
|
|
1012
|
+
# ============================================================
|
|
1013
|
+
|
|
1014
|
+
# MCP tool registry for this server. The handlers below register themselves via
# @registry.tool(...), and registry.server is what main() runs over stdio and
# what _run_streamable_http() hands to the Streamable HTTP session manager.
registry = ToolRegistry("cloto-mcp-embedding")
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
@registry.tool(
    "embed",
    "Generate vector embeddings for input texts.",
    {
        "type": "object",
        "properties": {
            "texts": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Texts to embed (batch, max 100)",
            }
        },
        "required": ["texts"],
    },
)
async def handle_embed_tool(arguments: dict) -> dict:
    """MCP ``embed`` tool: embed 1-100 strings with the active provider.

    Returns ``{"embeddings": [...], "dimensions": int}`` or ``{"error": str}``.
    """
    if _provider is None:
        return {"error": "Provider not initialized"}

    texts = arguments.get("texts", [])
    # Non-string entries would otherwise reach the provider and fail there.
    if not isinstance(texts, list) or not texts or not all(isinstance(t, str) for t in texts):
        return {"error": "'texts' must be a non-empty array of strings"}

    if len(texts) > 100:
        return {"error": "Batch size exceeds limit (max 100)"}

    try:
        embeddings = await _provider.embed(texts)
        return {
            "embeddings": embeddings,
            "dimensions": _provider.dimensions(),
        }
    except Exception as e:
        return {"error": str(e)}
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
@registry.tool(
    "index",
    "Index text items for vector similarity search. Each item gets embedded and stored persistently.",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace for isolation (e.g., 'cpersona:agent-id')",
                "default": "default",
            },
            "items": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "id": {"type": "string"},
                        "text": {"type": "string"},
                    },
                    "required": ["id", "text"],
                },
                "description": "Items to index (max 100)",
            },
        },
        "required": ["items"],
    },
)
async def handle_index_tool(arguments: dict) -> dict:
    """MCP ``index`` tool: embed and store 1-100 ``{"id", "text"}`` items.

    Returns ``{"ok": True, "indexed": n}`` or ``{"error": str}``.
    """
    if _provider is None or _vector_index is None:
        return {"error": "Not initialized or index disabled"}

    namespace = arguments.get("namespace", "default")
    if not isinstance(namespace, str):
        return {"error": "namespace must be a string"}

    items = arguments.get("items", [])
    # Without the list check, len() on a non-list would raise outside the try below.
    if not isinstance(items, list) or not items or len(items) > 100:
        return {"error": "items must be 1-100 entries"}

    for item in items:
        if not isinstance(item, dict) or "id" not in item or "text" not in item:
            return {"error": "Each item must have 'id' and 'text'"}
        item_id = item["id"]
        if (
            not isinstance(item["text"], str)
            or isinstance(item_id, bool)
            or not isinstance(item_id, (str, int))
        ):
            return {"error": "Each item needs a string or integer 'id' and a string 'text'"}

    try:
        indexed = await _vector_index.index(namespace, items, _provider)
        return {"ok": True, "indexed": indexed}
    except Exception as e:
        return {"error": str(e)}
|
|
1094
|
+
|
|
1095
|
+
|
|
1096
|
+
@registry.tool(
    "search",
    "Search indexed vectors by semantic similarity. Returns top-K results with scores.",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace to search within",
                "default": "default",
            },
            "query": {
                "type": "string",
                "description": "Search query text",
            },
            "limit": {
                "type": "integer",
                "description": "Max results to return (default: 10)",
                "default": 10,
            },
            "min_similarity": {
                "type": "number",
                "description": "Minimum cosine similarity threshold (default: 0.3)",
                "default": 0.3,
            },
        },
        "required": ["query"],
    },
)
async def handle_search_tool(arguments: dict) -> dict:
    """MCP ``search`` tool: top-K similarity search within a namespace.

    ``limit`` defaults to 10 and is capped at 500. Returns
    ``{"results": [{"id", "score"}, ...]}`` or ``{"error": str}``.
    """
    if _provider is None or _vector_index is None:
        return {"error": "Not initialized or index disabled"}

    namespace = arguments.get("namespace", "default")
    query = arguments.get("query", "")

    # Parse numeric arguments defensively; bad values used to raise out of the tool.
    try:
        limit = min(int(arguments.get("limit", 10)), 500)
        min_similarity = float(arguments.get("min_similarity", 0.3))
    except (TypeError, ValueError, OverflowError):
        return {"error": "limit must be an integer and min_similarity a number"}

    if not isinstance(query, str) or not query.strip():
        return {"error": "query must be non-empty"}

    try:
        results = await _vector_index.search(namespace, query, limit, min_similarity, _provider)
        return {"results": results}
    except Exception as e:
        return {"error": str(e)}
|
|
1142
|
+
|
|
1143
|
+
|
|
1144
|
+
@registry.tool(
    "remove",
    "Remove items from the vector index by ID.",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace containing the items",
                "default": "default",
            },
            "ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Item IDs to remove",
            },
        },
        "required": ["ids"],
    },
)
async def handle_remove_tool(arguments: dict) -> dict:
    """MCP ``remove`` tool: delete item ids from a namespace.

    Returns ``{"ok": True, "removed": n}`` or ``{"error": str}``.
    """
    if _vector_index is None:
        return {"error": "Index not initialized or disabled"}

    namespace = arguments.get("namespace", "default")
    ids = arguments.get("ids", [])
    # A bare string would otherwise be iterated character by character by
    # VectorIndex.remove(), deleting the wrong ids.
    if not isinstance(ids, list) or not ids:
        return {"error": "ids must be a non-empty array"}

    try:
        removed = await _vector_index.remove(namespace, ids)
        return {"ok": True, "removed": removed}
    except Exception as e:
        return {"error": str(e)}
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
@registry.tool(
    "purge",
    "Remove ALL vectors in a namespace. Use for bulk cleanup (e.g., agent deletion).",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace to purge completely",
            },
        },
        "required": ["namespace"],
    },
)
async def handle_purge_tool(arguments: dict) -> dict:
    """MCP ``purge`` tool: delete every vector in a namespace.

    Returns ``{"ok": True, "removed": n}`` or ``{"error": str}``.
    """
    if _vector_index is None:
        return {"error": "Index not initialized or disabled"}

    namespace = arguments.get("namespace", "")
    # Same rule as the HTTP /purge handler: reject non-strings and blank names.
    if not isinstance(namespace, str) or not namespace.strip():
        return {"error": "namespace is required"}

    try:
        removed = await _vector_index.purge_namespace(namespace)
        return {"ok": True, "removed": removed}
    except Exception as e:
        return {"error": str(e)}
|
|
1207
|
+
|
|
1208
|
+
|
|
1209
|
+
# ============================================================
|
|
1210
|
+
# Main
|
|
1211
|
+
# ============================================================
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
async def _run_streamable_http() -> None:
    """Run embedding as a Streamable HTTP MCP server (no stdio, no REST /embed).

    Enabled by setting EMBEDDING_TRANSPORT=streamable-http. Listens on
    EMBEDDING_MCP_HTTP_HOST (default 0.0.0.0) and EMBEDDING_MCP_HTTP_PORT
    (default 8403). The MCP endpoint is mounted at /embedding/mcp, /embedding,
    /mcp and / so it can coexist with other services behind path-based reverse
    proxies. The provider (and the index, when enabled) is always shut down on
    exit, including when index initialization itself fails.
    """
    global _provider, _vector_index

    import contextlib
    from collections.abc import AsyncIterator

    import uvicorn
    from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
    from starlette.applications import Starlette
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware
    from starlette.routing import Mount

    _provider = create_provider()
    await _provider.initialize()

    try:
        if EMBEDDING_INDEX_ENABLED:
            _vector_index = VectorIndex(EMBEDDING_INDEX_DB_PATH)
            await _vector_index.initialize()

        session_manager = StreamableHTTPSessionManager(
            app=registry.server,
            stateless=True,
        )

        async def mcp_endpoint(scope, receive, send):
            await session_manager.handle_request(scope, receive, send)

        @contextlib.asynccontextmanager
        async def lifespan(_app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                logger.info("Embedding Streamable HTTP server ready")
                yield

        app = Starlette(
            routes=[
                Mount("/embedding/mcp", app=mcp_endpoint),
                Mount("/embedding", app=mcp_endpoint),
                Mount("/mcp", app=mcp_endpoint),
                Mount("/", app=mcp_endpoint),
            ],
            middleware=[
                Middleware(
                    CORSMiddleware,
                    allow_origins=["https://claude.ai", "https://www.claude.ai"],
                    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
                    allow_headers=[
                        "Authorization",
                        "Content-Type",
                        "Mcp-Session-Id",
                        "Mcp-Protocol-Version",
                        "Last-Event-Id",
                    ],
                    expose_headers=["Mcp-Session-Id"],
                ),
            ],
            lifespan=lifespan,
        )

        host = os.environ.get("EMBEDDING_MCP_HTTP_HOST", "0.0.0.0")
        port = int(os.environ.get("EMBEDDING_MCP_HTTP_PORT", "8403"))
        logger.info("Starting Embedding Streamable HTTP MCP on %s:%d", host, port)

        config = uvicorn.Config(app, host=host, port=port, log_level="info")
        server = uvicorn.Server(config)
        await server.serve()
    finally:
        if _vector_index is not None:
            await _vector_index.shutdown()
        await _provider.shutdown()
        logger.info("Embedding Streamable HTTP server shut down")
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
async def main():
    """Async entry point: start the embedding server on the configured transport.

    With EMBEDDING_TRANSPORT=streamable-http this delegates to
    _run_streamable_http(). Otherwise it initializes the provider, plus the
    vector index when enabled. It then runs the REST endpoint as a background
    task and serves MCP over stdio until the stream ends. Teardown always
    shuts down the index and provider, even if the HTTP task crashed or index
    initialization failed.
    """
    global _provider, _vector_index

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    transport = os.environ.get("EMBEDDING_TRANSPORT", "stdio")
    if transport == "streamable-http":
        await _run_streamable_http()
        return

    logger.info(
        "Starting embedding server (provider=%s, http_port=%d, index=%s)",
        EMBEDDING_PROVIDER,
        EMBEDDING_HTTP_PORT,
        "enabled" if EMBEDDING_INDEX_ENABLED else "disabled",
    )

    _provider = create_provider()
    await _provider.initialize()

    def _report_http_failure(task: asyncio.Task) -> None:
        # Surface a crashed REST endpoint (e.g. port already in use) right away;
        # otherwise stdio keeps running with the HTTP side silently dead.
        if not task.cancelled() and task.exception() is not None:
            logger.error("HTTP embedding endpoint stopped", exc_info=task.exception())

    http_task: asyncio.Task | None = None
    try:
        # Initialize vector index if enabled
        if EMBEDDING_INDEX_ENABLED:
            _vector_index = VectorIndex(EMBEDDING_INDEX_DB_PATH)
            await _vector_index.initialize()

        # Start HTTP endpoint as background task
        http_task = asyncio.create_task(run_http_server(EMBEDDING_HTTP_PORT))
        http_task.add_done_callback(_report_http_failure)

        async with stdio_server() as (read_stream, write_stream):
            await registry.server.run(
                read_stream,
                write_stream,
                registry.server.create_initialization_options(),
            )
    finally:
        if http_task is not None:
            http_task.cancel()
            try:
                await http_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # Already logged by _report_http_failure; re-raising here would
                # skip the index/provider shutdown below.
                pass
        if _vector_index is not None:
            await _vector_index.shutdown()
        await _provider.shutdown()
        logger.info("Embedding server shut down")
|
|
1343
|
+
|
|
1344
|
+
|
|
1345
|
+
def run():
    """Synchronous entry point for the console script and ``python -m cembedding``.

    Drives the async ``main()`` coroutine to completion on a fresh event loop.
    """
    entry = main()
    asyncio.run(entry)
|
|
1348
|
+
|
|
1349
|
+
|
|
1350
|
+
if __name__ == "__main__":
    # Direct script execution: hand off to the synchronous entry point.
    run()
|