cembedding 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cembedding/server.py ADDED
@@ -0,0 +1,1351 @@
1
+ """
2
+ Cloto MCP Server: Vector Embedding
3
+ Pluggable embedding provider with HTTP endpoint for inter-server communication.
4
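+ Providers: api_openai (OpenAI-compatible API), onnx_miniml, onnx_jina_v5_nano, onnx_bge_m3 (local ONNX models), mlx_bge_m3 (Apple Silicon MLX), auto_bge_m3 (MLX when available, otherwise ONNX BGE-M3).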
5
+
6
+ v0.2.0: Vector Index — persistent index + search endpoints for centralized vector search.
7
+
8
+ Design: docs/CPERSONA_MEMORY_DESIGN.md Section 5
9
+ """
10
+
11
+ import asyncio
12
+ import logging
13
+ import os
14
+ import platform as _platform
15
+ import struct
16
+ import sys
17
+ from abc import ABC, abstractmethod
18
+
19
+ import httpx
20
+ import numpy as np
21
+ from aiohttp import web
22
+ from mcp.server.stdio import stdio_server
23
+
24
+
25
+ from cembedding._vendored_mcp_common.mcp_utils import ToolRegistry
26
+
27
logger = logging.getLogger(__name__)

# ============================================================
# Configuration
# ============================================================
# Every setting below is read from the process environment once, at import
# time. Out-of-range numeric values raise ValueError so a misconfigured
# deployment fails at startup rather than on the first request.

EMBEDDING_PROVIDER = os.environ.get("EMBEDDING_PROVIDER", "api_openai")
EMBEDDING_HTTP_PORT = int(os.environ.get("EMBEDDING_HTTP_PORT", "8401"))
if not (1 <= EMBEDDING_HTTP_PORT <= 65535):
    raise ValueError(f"EMBEDDING_HTTP_PORT must be 1-65535, got {EMBEDDING_HTTP_PORT}")
EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY", "")
EMBEDDING_API_URL = os.environ.get("EMBEDDING_API_URL", "https://api.openai.com/v1/embeddings")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "")  # provider-dependent default
EMBEDDING_TIMEOUT = int(os.environ.get("EMBEDDING_TIMEOUT_SECS", "30"))

# Vector Index (v0.2.0)
EMBEDDING_INDEX_ENABLED = os.environ.get("EMBEDDING_INDEX_ENABLED", "true").lower() == "true"
EMBEDDING_INDEX_DB_PATH = os.environ.get("EMBEDDING_INDEX_DB_PATH", "data/embedding_index.db")

# ONNX tokenization max sequence length (1-8192). MiniLM is clamped to 512
# internally (positional embeddings cap). Jina-v5-nano supports up to 8192.
ONNX_MAX_SEQ_LEN = int(os.environ.get("ONNX_MAX_SEQ_LEN", "2048"))
if not (1 <= ONNX_MAX_SEQ_LEN <= 8192):
    raise ValueError(f"ONNX_MAX_SEQ_LEN must be 1-8192, got {ONNX_MAX_SEQ_LEN}")

# ONNX-specific — resolve relative paths against CLOTO_PROJECT_DIR when running
# inside a sandbox (isolation changes the working directory).
_project_dir = os.environ.get("CLOTO_PROJECT_DIR", "")
_MODEL_DIRS = {
    "onnx_miniml": "data/models/all-MiniLM-L6-v2",
    "onnx_jina_v5_nano": "data/models/jina-embeddings-v5-text-nano",
    "onnx_bge_m3": "data/models/bge-m3",
    "mlx_bge_m3": "data/models/bge-m3-mlx",
    # auto_bge_m3 falls back to the ONNX BGE-M3 provider, which is built from
    # ONNX_MODEL_DIR. Without this entry the default below would point the
    # BGE-M3 provider at the MiniLM directory.
    "auto_bge_m3": "data/models/bge-m3",
}
_default_model_dir = _MODEL_DIRS.get(EMBEDDING_PROVIDER, "data/models/all-MiniLM-L6-v2")

# An explicit ONNX_MODEL_DIR is used verbatim; only the built-in default is
# anchored to the project directory.
ONNX_MODEL_DIR = os.environ.get("ONNX_MODEL_DIR", "")
if not ONNX_MODEL_DIR:
    if _project_dir and not os.path.isabs(_default_model_dir):
        ONNX_MODEL_DIR = os.path.join(_project_dir, _default_model_dir)
    else:
        ONNX_MODEL_DIR = _default_model_dir
68
+
69
+
70
def _select_ort_providers() -> list:
    """Pick the ONNX Runtime execution providers to request, in priority order.

    With ONNX_EP_PREFERENCE unset or blank, the following accelerators are
    probed in this order and kept when onnxruntime reports them available:
    CoreML, DirectML, CUDA, ROCm.

    With ONNX_EP_PREFERENCE set, it is parsed as a comma-separated list and
    filtered down to the providers onnxruntime reports as available, keeping
    the order given.

    In both modes CPUExecutionProvider is appended when not already present,
    so the returned list is never empty.

    Fail-open: if importing onnxruntime or querying its providers raises, the
    result is ["CPUExecutionProvider"].
    """
    cpu = "CPUExecutionProvider"

    try:
        import onnxruntime as ort

        installed = set(ort.get_available_providers())
    except Exception:
        return [cpu]

    raw_preference = os.environ.get("ONNX_EP_PREFERENCE", "").strip()
    if raw_preference:
        wanted = [name.strip() for name in raw_preference.split(",") if name.strip()]
    else:
        wanted = (
            "CoreMLExecutionProvider",  # macOS
            "DmlExecutionProvider",  # Windows (DirectML)
            "CUDAExecutionProvider",  # Linux/Windows NVIDIA (onnxruntime-gpu)
            "ROCmExecutionProvider",  # Linux AMD (onnxruntime ROCm)
        )

    chosen = [name for name in wanted if name in installed]
    if cpu not in chosen:
        chosen.append(cpu)
    return chosen
114
+
115
+
116
+ # ============================================================
117
+ # Provider Abstraction
118
+ # ============================================================
119
+
120
+
121
class EmbeddingProvider(ABC):
    """Interface that every embedding backend in this module implements.

    Subclasses must supply ``initialize``, ``embed`` and ``dimensions``;
    ``shutdown`` is optional and does nothing unless overridden.
    """

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the backend for use (load a model, open a client, ...)."""

    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per entry of ``texts``."""

    @abstractmethod
    def dimensions(self) -> int:
        """Return the length of the vectors produced by ``embed``."""

    async def shutdown(self) -> None:
        """Release any resources held by the backend (no-op by default)."""
        return None
138
+
139
+
140
+ # ============================================================
141
+ # api_openai Provider
142
+ # ============================================================
143
+
144
+
145
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by an OpenAI-compatible HTTP embeddings API.

    Vectors returned by the API are L2-normalized before being handed back,
    so a dot product between two results equals their cosine similarity.
    """

    def __init__(self, api_key: str, api_url: str, model: str, timeout: int):
        """Store connection settings; no network activity happens here.

        Args:
            api_key: Bearer token sent in the Authorization header.
            api_url: Full URL of the embeddings endpoint.
            model: Model name; an empty value selects "text-embedding-3-small".
            timeout: Timeout passed to the HTTP client.
        """
        self._api_key = api_key
        self._api_url = api_url
        self._model = model or "text-embedding-3-small"
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None
        # Initial guess only; replaced by the real vector length after the
        # first successful embed() call.
        self._dimensions = int(os.environ.get("EMBEDDING_DIMENSIONS", "1536"))

    async def initialize(self) -> None:
        """Create the HTTP client.

        Raises:
            ValueError: If no API key was configured.
        """
        if not self._api_key:
            raise ValueError("EMBEDDING_API_KEY is required for api_openai provider")
        self._client = httpx.AsyncClient(timeout=self._timeout)
        logger.info(
            "OpenAI embedding provider initialized (model=%s, url=%s)",
            self._model,
            self._api_url,
        )

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` and return one L2-normalized vector per input.

        Results are ordered to match ``texts`` using the ``index`` field of
        each response item when the server provides one.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
            ValueError: If the API returns a different number of vectors
                than inputs.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        if not self._client:
            raise RuntimeError("Provider not initialized")

        # Nothing to embed: skip the round trip instead of sending an empty
        # input list to the API.
        if not texts:
            return []

        response = await self._client.post(
            self._api_url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            json={"model": self._model, "input": texts},
        )
        response.raise_for_status()

        data = response.json()

        def _position(pair):
            # Prefer the server-supplied index; fall back to arrival order
            # for compatible servers that omit it.
            arrival, item = pair
            declared = item.get("index")
            return declared if isinstance(declared, int) else arrival

        # sorted() is stable, so items with equal or missing indices keep the
        # order the server sent them in.
        ordered = sorted(enumerate(data["data"]), key=_position)
        embeddings = [item["embedding"] for _, item in ordered]

        if len(embeddings) != len(texts):
            raise ValueError(
                f"Embedding API returned {len(embeddings)} vectors for {len(texts)} inputs"
            )

        # Update dimensions from actual response
        self._dimensions = len(embeddings[0])

        # L2-normalize for consistent cosine similarity via dot product
        result = []
        for emb in embeddings:
            vec = np.array(emb, dtype=np.float32)
            norm = np.linalg.norm(vec)
            # Leave (near-)zero vectors untouched rather than dividing by ~0.
            if norm > 1e-9:
                vec = vec / norm
            result.append(vec.tolist())

        return result

    def dimensions(self) -> int:
        """Return the configured, or last observed, embedding length."""
        return self._dimensions

    async def shutdown(self) -> None:
        """Close the HTTP client if one is open."""
        if self._client:
            await self._client.aclose()
            self._client = None
205
+
206
+
207
+ # ============================================================
208
+ # onnx_miniml Provider
209
+ # ============================================================
210
+
211
+
212
class OnnxMiniLMProvider(EmbeddingProvider):
    """Local all-MiniLM-L6-v2 ONNX embedding provider.

    Mean pooling over non-padding tokens followed by L2 normalization;
    ``dimensions()`` reports 384.
    """

    def __init__(self, model_dir: str):
        """Remember the model directory; nothing is loaded until initialize()."""
        self._model_dir = model_dir
        self._session = None
        self._tokenizer = None
        # Allows only one embed() call at a time to run inference.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Load the ONNX session and tokenizer, downloading the model if missing.

        Raises:
            ImportError: If onnxruntime or tokenizers is not installed.
            FileNotFoundError: If the model files are absent and cannot be
                downloaded.
        """
        try:
            import onnxruntime  # noqa: F401  # availability check — session is created in _create_ort_session
            from tokenizers import Tokenizer
        except ImportError as exc:
            raise ImportError(
                "onnx_miniml provider requires: uv pip install onnxruntime tokenizers\n"
                "Or: uv pip install cloto-mcp-embedding[onnx]"
            ) from exc

        model_path = os.path.join(self._model_dir, "model.onnx")
        tokenizer_path = os.path.join(self._model_dir, "tokenizer.json")

        # Auto-download model if missing
        if not os.path.exists(model_path) or not os.path.exists(tokenizer_path):
            logger.info("ONNX model not found, downloading automatically...")
            try:
                from cembedding.download_model import download

                # NOTE(review): download() is called without a target
                # directory, unlike the sibling providers' download helpers —
                # confirm it writes to self._model_dir.
                if not download():
                    raise FileNotFoundError(f"Failed to download ONNX model to {self._model_dir}")
            except ImportError as exc:
                raise FileNotFoundError(
                    f"ONNX model not found at {model_path}. "
                    f"Download with: python -m cembedding.download_model"
                ) from exc

        providers = _select_ort_providers()

        # Use the shared session factory so MiniLM gets the same CoreML
        # fallback chain (options -> plain -> CPU) as the other ONNX
        # providers, instead of failing outright when CoreML cannot
        # initialize.
        self._session = _create_ort_session(model_path, providers)
        self._tokenizer = Tokenizer.from_file(tokenizer_path)
        # MiniLM positional embeddings cap at 512 — clamp ONNX_MAX_SEQ_LEN.
        miniml_seq_len = min(ONNX_MAX_SEQ_LEN, 512)
        if miniml_seq_len < ONNX_MAX_SEQ_LEN:
            logger.warning(
                "MiniLM max_position=512, clamping ONNX_MAX_SEQ_LEN=%d to 512",
                ONNX_MAX_SEQ_LEN,
            )
        # Every input is padded/truncated to exactly miniml_seq_len tokens.
        self._tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=miniml_seq_len)
        self._tokenizer.enable_truncation(max_length=miniml_seq_len)

        logger.info(
            "ONNX MiniLM provider initialized (dir=%s, seq_len=%d, requested=%s, active=%s)",
            self._model_dir,
            miniml_seq_len,
            providers,
            self._session.get_providers(),
        )

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` on a worker thread, one batch at a time.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if not self._session or not self._tokenizer:
            raise RuntimeError("Provider not initialized")

        async with self._lock:
            return await asyncio.get_event_loop().run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Synchronous embedding (run in executor to avoid blocking)."""
        encodings = self._tokenizer.encode_batch(texts)

        input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)
        token_type_ids = np.array([e.type_ids for e in encodings], dtype=np.int64)

        outputs = self._session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            },
        )
        token_embeddings = outputs[0]  # (batch, seq_len, hidden_dim)

        # Mean pooling over real tokens only: padding positions are zeroed by
        # the mask and excluded from the divisor.
        mask_expanded = np.expand_dims(attention_mask, -1).astype(np.float32)
        sum_embeddings = np.sum(token_embeddings * mask_expanded, axis=1)
        sum_mask = np.clip(np.sum(mask_expanded, axis=1), a_min=1e-9, a_max=None)
        mean_pooled = sum_embeddings / sum_mask

        # L2 normalization; the clip guards against division by zero.
        norms = np.linalg.norm(mean_pooled, axis=1, keepdims=True)
        norms = np.clip(norms, a_min=1e-9, a_max=None)
        normalized = mean_pooled / norms

        return normalized.tolist()

    def dimensions(self) -> int:
        """Return the fixed embedding length reported for this model."""
        return 384

    async def shutdown(self) -> None:
        """Drop references to the session and tokenizer."""
        self._session = None
        self._tokenizer = None
313
+
314
+
315
+ # ============================================================
316
+ # Platform detection helpers
317
+ # ============================================================
318
+
319
+
320
def _is_apple_silicon() -> bool:
    """Return True only on macOS ("darwin") running on an arm64 machine."""
    if sys.platform != "darwin":
        return False
    return _platform.machine() == "arm64"
322
+
323
+
324
def _mlx_available() -> bool:
    """Return True when both ``mlx.core`` and ``mlx_embeddings`` can be imported."""
    try:
        import mlx.core  # noqa: F401
        import mlx_embeddings  # noqa: F401
    except ImportError:
        return False
    else:
        return True
332
+
333
+
334
+ # ============================================================
335
+ # CoreML-aware session factory (shared by ONNX providers)
336
+ # ============================================================
337
+
338
+
339
def _create_ort_session(model_path: str, providers: list):
    """Build an onnxruntime InferenceSession, degrading gracefully on CoreML.

    When CoreML is not among ``providers`` the session is created directly
    with the list as given. Otherwise three attempts are made in order:

    1. CoreML first, with MLProgram / dynamic-shape provider options,
       followed by the remaining providers.
    2. The original ``providers`` list without any provider options.
    3. CPUExecutionProvider only.

    Failures of attempts 1 and 2 are logged as warnings; a failure of the
    last attempt propagates to the caller.
    """
    import onnxruntime as ort

    coreml = "CoreMLExecutionProvider"
    if coreml not in providers:
        return ort.InferenceSession(model_path, providers=providers)

    coreml_options = {
        "ModelFormat": "MLProgram",
        "MLComputeUnits": "ALL",
        "RequireStaticInputShapes": "0",
    }
    # CoreML moves to the front; every other provider keeps its relative order.
    others = [name for name in providers if name != coreml]
    tuned = [(coreml, coreml_options)] + others

    try:
        return ort.InferenceSession(model_path, providers=tuned)
    except Exception as first_error:
        logger.warning("CoreML init with provider_options failed, retrying without options: %s", first_error)
        try:
            return ort.InferenceSession(model_path, providers=providers)
        except Exception as second_error:
            logger.warning("CoreML plain init failed, falling back to CPU-only: %s", second_error)
            return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
372
+
373
+
374
+ # ============================================================
375
+ # mlx_bge_m3 Provider (macOS Apple Silicon only)
376
+ # ============================================================
377
+
378
MLX_BGE_M3_REPO = "mlx-community/bge-m3-mlx-fp16"


class MlxBgeM3Provider(EmbeddingProvider):
    """bge-m3 embedding provider running through mlx-embeddings.

    Intended for macOS on Apple Silicon; the ``auto_bge_m3`` provider setting
    selects the ONNX BGE-M3 provider instead when MLX is unavailable.
    ``dimensions()`` reports 1024.
    """

    def __init__(self, model_path: str = MLX_BGE_M3_REPO):
        """Keep the model location (HuggingFace repo ID or local directory)."""
        self._lock = asyncio.Lock()
        self._model = None
        self._tokenizer = None
        self._model_path = model_path

    async def initialize(self) -> None:
        """Load model and tokenizer on a worker thread.

        Raises:
            ImportError: If mlx-embeddings is not installed.
        """
        try:
            from mlx_embeddings.utils import load as mlx_load
        except ImportError:
            raise ImportError("mlx_bge_m3 requires: uv pip install mlx-embeddings")

        loop = asyncio.get_event_loop()
        loaded = await loop.run_in_executor(None, mlx_load, self._model_path)
        self._model, self._tokenizer = loaded
        logger.info("MLX BGE-M3 provider initialized (path=%s)", self._model_path)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` on a worker thread; calls are serialized by a lock.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if not (self._model and self._tokenizer):
            raise RuntimeError("Provider not initialized")

        async with self._lock:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Tokenize, run the model and return its text embeddings as lists."""
        import mlx.core as mx

        token_limit = min(ONNX_MAX_SEQ_LEN, 8192)
        batch = self._tokenizer.batch_encode_plus(
            texts,
            return_tensors="mlx",
            padding=True,
            truncation=True,
            max_length=token_limit,
        )
        result = self._model(
            batch["input_ids"],
            attention_mask=batch.get("attention_mask"),
        )
        vectors = result.text_embeds  # already L2-normalized
        mx.eval(vectors)
        return vectors.tolist()

    def dimensions(self) -> int:
        """Return the fixed embedding length reported for this model."""
        return 1024

    async def shutdown(self) -> None:
        """Drop references to the model and tokenizer."""
        self._model = None
        self._tokenizer = None
435
+
436
+
437
+ # ============================================================
438
+ # onnx_bge_m3 Provider
439
+ # ============================================================
440
+
441
+
442
class OnnxBgeM3Provider(EmbeddingProvider):
    """ONNX provider for BAAI/bge-m3 (int8 quantized, XLM-RoBERTa based).

    Pools by taking the first (CLS) token of each sequence, then
    L2-normalizes. ``dimensions()`` reports 1024. ``token_type_ids`` is only
    fed when the loaded model declares that input.
    """

    def __init__(self, model_dir: str):
        """Remember the model directory; nothing is loaded until initialize()."""
        self._lock = asyncio.Lock()
        self._session = None
        self._tokenizer = None
        self._model_dir = model_dir

    async def initialize(self) -> None:
        """Load the ONNX session and tokenizer, downloading the model if missing.

        Raises:
            ImportError: If onnxruntime or tokenizers is not installed.
            FileNotFoundError: If the model files are absent and cannot be
                downloaded.
        """
        try:
            import onnxruntime  # noqa: F401
            from tokenizers import Tokenizer
        except ImportError:
            raise ImportError("onnx_bge_m3 provider requires: uv pip install onnxruntime tokenizers")

        model_dir_abs = os.path.abspath(self._model_dir)
        model_path = os.path.join(model_dir_abs, "model.onnx")
        tokenizer_path = os.path.join(model_dir_abs, "tokenizer.json")

        files_present = os.path.exists(model_path) and os.path.exists(tokenizer_path)
        if not files_present:
            logger.info("BGE-M3 ONNX model not found, downloading...")
            try:
                from cembedding.download_model import download_bge_m3

                downloaded = download_bge_m3(model_dir_abs)
                if not downloaded:
                    raise FileNotFoundError(f"Failed to download BGE-M3 model to {model_dir_abs}")
            except ImportError:
                raise FileNotFoundError(
                    f"ONNX model not found at {model_path}. "
                    f"Download with: python -m cembedding.download_model --model bge-m3"
                )

        providers = _select_ort_providers()
        self._session = _create_ort_session(model_path, providers)

        bge_seq_len = min(ONNX_MAX_SEQ_LEN, 8192)
        tokenizer = Tokenizer.from_file(tokenizer_path)
        self._tokenizer = tokenizer
        # XLM-RoBERTa pad token is <pad> (id=1)
        tokenizer.enable_padding(pad_id=1, pad_token="<pad>", length=bge_seq_len)
        tokenizer.enable_truncation(max_length=bge_seq_len)

        logger.info(
            "ONNX BGE-M3 provider initialized (dir=%s, seq_len=%d, requested=%s, active=%s)",
            self._model_dir,
            bge_seq_len,
            providers,
            self._session.get_providers(),
        )

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` on a worker thread; calls are serialized by a lock.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if not (self._session and self._tokenizer):
            raise RuntimeError("Provider not initialized")

        async with self._lock:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Run inference synchronously and pool with the CLS token."""
        batch = self._tokenizer.encode_batch(texts)

        ids = np.array([enc.ids for enc in batch], dtype=np.int64)
        mask = np.array([enc.attention_mask for enc in batch], dtype=np.int64)

        feed = {"input_ids": ids, "attention_mask": mask}
        declared_inputs = [spec.name for spec in self._session.get_inputs()]
        if "token_type_ids" in declared_inputs:
            feed["token_type_ids"] = np.zeros_like(ids)

        hidden = self._session.run(None, feed)[0]  # (batch, seq_len, 1024)

        # CLS-token pooling: first token of each sequence
        cls_vectors = hidden[:, 0, :].astype(np.float32, copy=False)

        # L2 normalization; the clip guards against division by zero.
        lengths = np.clip(np.linalg.norm(cls_vectors, axis=1, keepdims=True), a_min=1e-9, a_max=None)
        return (cls_vectors / lengths).tolist()

    def dimensions(self) -> int:
        """Return the fixed embedding length reported for this model."""
        return 1024

    async def shutdown(self) -> None:
        """Drop references to the session and tokenizer."""
        self._session = None
        self._tokenizer = None
533
+
534
+
535
+ # ============================================================
536
+ # onnx_jina_v5_nano Provider
537
+ # ============================================================
538
+
539
+
540
class OnnxJinaV5NanoProvider(EmbeddingProvider):
    """ONNX provider for jina-embeddings-v5-text-nano.

    Pools by taking the hidden state of the last non-padding token of each
    sequence (unlike MiniLM's mean pooling), then L2-normalizes.
    ``dimensions()`` reports 768.
    """

    def __init__(self, model_dir: str):
        """Remember the model directory; nothing is loaded until initialize()."""
        self._lock = asyncio.Lock()
        self._session = None
        self._tokenizer = None
        self._model_dir = model_dir

    async def initialize(self) -> None:
        """Load the ONNX session and tokenizer, downloading the model if missing.

        Raises:
            ImportError: If onnxruntime or tokenizers is not installed.
            FileNotFoundError: If the model files are absent and cannot be
                downloaded.
        """
        try:
            import onnxruntime  # noqa: F401  # availability check — session is created in _create_session_with_fallback
            from tokenizers import Tokenizer
        except ImportError:
            raise ImportError(
                "onnx_jina_v5_nano provider requires: uv pip install onnxruntime tokenizers\n"
                "Or: uv pip install cloto-mcp-embedding[onnx]"
            )

        model_path = os.path.join(self._model_dir, "model.onnx")
        tokenizer_path = os.path.join(self._model_dir, "tokenizer.json")

        # Auto-download model if missing
        files_present = os.path.exists(model_path) and os.path.exists(tokenizer_path)
        if not files_present:
            logger.info("Jina-v5-nano ONNX model not found, downloading...")
            try:
                from cembedding.download_model import download_jina_v5_nano

                downloaded = download_jina_v5_nano(self._model_dir)
                if not downloaded:
                    raise FileNotFoundError(f"Failed to download model to {self._model_dir}")
            except ImportError:
                raise FileNotFoundError(
                    f"ONNX model not found at {model_path}. "
                    f"Download with: python -m cembedding.download_model --model jina-v5-nano"
                )

        providers = _select_ort_providers()
        self._session = self._create_session_with_fallback(model_path, providers)

        tokenizer = Tokenizer.from_file(tokenizer_path)
        self._tokenizer = tokenizer
        # jina-v5-nano supports up to 8K context via RoPE.
        tokenizer.enable_padding(pad_id=0, pad_token="<pad>", length=ONNX_MAX_SEQ_LEN)
        tokenizer.enable_truncation(max_length=ONNX_MAX_SEQ_LEN)

        logger.info(
            "ONNX Jina-v5-nano provider initialized (dir=%s, seq_len=%d, requested=%s, active=%s)",
            self._model_dir,
            ONNX_MAX_SEQ_LEN,
            providers,
            self._session.get_providers(),
        )

    def _create_session_with_fallback(self, model_path: str, providers: list):
        """Delegate session creation to the shared CoreML-aware factory."""
        return _create_ort_session(model_path, providers)

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` on a worker thread; calls are serialized by a lock.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if not (self._session and self._tokenizer):
            raise RuntimeError("Provider not initialized")

        async with self._lock:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, self._embed_sync, texts)

    def _embed_sync(self, texts: list[str]) -> list[list[float]]:
        """Run inference synchronously and pool with the last real token."""
        batch = self._tokenizer.encode_batch(texts)

        ids = np.array([enc.ids for enc in batch], dtype=np.int64)
        mask = np.array([enc.attention_mask for enc in batch], dtype=np.int64)

        # jina-v5-nano may not use token_type_ids — check model inputs
        feed = {"input_ids": ids, "attention_mask": mask}
        declared_inputs = [spec.name for spec in self._session.get_inputs()]
        if "token_type_ids" in declared_inputs:
            feed["token_type_ids"] = np.zeros_like(ids)

        hidden = self._session.run(None, feed)[0]  # (batch, seq_len, hidden_dim=768)

        # Last-Token pooling (vectorized): the attention-mask sum minus one is
        # the position of the last non-padding token; clamp at 0 so an
        # all-padding row still indexes validly.
        token_counts = np.sum(mask, axis=1)
        last_positions = np.maximum(token_counts - 1, 0).astype(np.int64)
        rows = np.arange(hidden.shape[0])
        pooled = hidden[rows, last_positions].astype(np.float32, copy=False)

        # L2 normalization; the clip guards against division by zero.
        lengths = np.clip(np.linalg.norm(pooled, axis=1, keepdims=True), a_min=1e-9, a_max=None)
        return (pooled / lengths).tolist()

    def dimensions(self) -> int:
        """Return the fixed embedding length reported for this model."""
        return 768

    async def shutdown(self) -> None:
        """Drop references to the session and tokenizer."""
        self._session = None
        self._tokenizer = None
639
+
640
+
641
+ # ============================================================
642
+ # Vector Index (v0.2.0)
643
+ # ============================================================
644
+
645
+
646
+ class VectorIndex:
647
+ """Persistent vector index with in-memory search.
648
+
649
+ Stores vectors in SQLite for durability, loads into memory for fast
650
+ brute-force dot product search. Namespaced to support multiple consumers.
651
+ """
652
+
653
+ def __init__(self, db_path: str):
654
+ self._db_path = db_path
655
+ self._db = None
656
+ # In-memory index: {namespace: {item_id: np.array(float32)}}
657
+ self._index: dict[str, dict[str, np.ndarray]] = {}
658
+
659
+ async def initialize(self) -> None:
660
+ import aiosqlite
661
+
662
+ db_dir = os.path.dirname(self._db_path)
663
+ if db_dir:
664
+ os.makedirs(db_dir, exist_ok=True)
665
+
666
+ self._db = await aiosqlite.connect(self._db_path)
667
+ await self._db.execute("PRAGMA journal_mode=WAL")
668
+ await self._db.execute("PRAGMA synchronous=NORMAL")
669
+ await self._db.executescript(
670
+ """
671
+ CREATE TABLE IF NOT EXISTS vectors (
672
+ namespace TEXT NOT NULL,
673
+ item_id TEXT NOT NULL,
674
+ vector BLOB NOT NULL,
675
+ created_at TEXT DEFAULT (datetime('now')),
676
+ PRIMARY KEY (namespace, item_id)
677
+ );
678
+ CREATE INDEX IF NOT EXISTS idx_vectors_ns ON vectors (namespace);
679
+ """
680
+ )
681
+ await self._db.commit()
682
+
683
+ # Load all vectors into memory
684
+ rows = await self._db.execute_fetchall("SELECT namespace, item_id, vector FROM vectors")
685
+ for ns, item_id, blob in rows:
686
+ if ns not in self._index:
687
+ self._index[ns] = {}
688
+ self._index[ns][item_id] = np.frombuffer(blob, dtype=np.float32).copy()
689
+
690
+ total = sum(len(v) for v in self._index.values())
691
+ logger.info("VectorIndex loaded: %d vectors across %d namespaces", total, len(self._index))
692
+
693
+ async def index(self, namespace: str, items: list[dict], provider: "EmbeddingProvider") -> int:
694
+ """Index items. Each item has 'id' and 'text'. Returns count indexed."""
695
+ if not self._db:
696
+ raise RuntimeError("VectorIndex not initialized")
697
+
698
+ texts = [item["text"] for item in items]
699
+ embeddings = await provider.embed(texts)
700
+
701
+ if namespace not in self._index:
702
+ self._index[namespace] = {}
703
+
704
+ indexed = 0
705
+ for item, emb in zip(items, embeddings):
706
+ item_id = item["id"]
707
+ vec = np.array(emb, dtype=np.float32)
708
+ blob = struct.pack(f"<{len(vec)}f", *vec)
709
+
710
+ await self._db.execute(
711
+ "INSERT OR REPLACE INTO vectors (namespace, item_id, vector) VALUES (?, ?, ?)",
712
+ (namespace, item_id, blob),
713
+ )
714
+ self._index[namespace][item_id] = vec
715
+ indexed += 1
716
+
717
+ await self._db.commit()
718
+ return indexed
719
+
720
+ async def search(
721
+ self,
722
+ namespace: str,
723
+ query: str,
724
+ limit: int,
725
+ min_similarity: float,
726
+ provider: "EmbeddingProvider",
727
+ ) -> list[dict]:
728
+ """Search for similar vectors. Returns [{id, score}, ...] sorted by score desc."""
729
+ ns_index = self._index.get(namespace)
730
+ if not ns_index:
731
+ return []
732
+
733
+ embeddings = await provider.embed([query])
734
+ if not embeddings or not embeddings[0]:
735
+ return []
736
+
737
+ query_vec = np.array(embeddings[0], dtype=np.float32)
738
+ query_dim = len(query_vec)
739
+
740
+ candidates = []
741
+ for item_id, vec in ns_index.items():
742
+ if len(vec) != query_dim:
743
+ continue
744
+ sim = float(np.dot(query_vec, vec))
745
+ if sim >= min_similarity:
746
+ candidates.append((sim, item_id))
747
+
748
+ # Top-K via heap
749
+ import heapq
750
+
751
+ top_k = heapq.nlargest(limit, candidates, key=lambda x: x[0])
752
+ return [{"id": item_id, "score": round(score, 4)} for score, item_id in top_k]
753
+
754
+ async def remove(self, namespace: str, ids: list[str]) -> int:
755
+ """Remove items from index. Returns count removed."""
756
+ if not self._db:
757
+ raise RuntimeError("VectorIndex not initialized")
758
+
759
+ removed = 0
760
+ ns_index = self._index.get(namespace, {})
761
+ for item_id in ids:
762
+ cursor = await self._db.execute(
763
+ "DELETE FROM vectors WHERE namespace = ? AND item_id = ?",
764
+ (namespace, item_id),
765
+ )
766
+ if cursor.rowcount > 0:
767
+ removed += 1
768
+ ns_index.pop(item_id, None)
769
+
770
+ await self._db.commit()
771
+ return removed
772
+
773
+ async def purge_namespace(self, namespace: str) -> int:
774
+ """Remove all vectors in a namespace. Returns count removed."""
775
+ if not self._db:
776
+ raise RuntimeError("VectorIndex not initialized")
777
+
778
+ cursor = await self._db.execute("DELETE FROM vectors WHERE namespace = ?", (namespace,))
779
+ await self._db.commit()
780
+ removed = len(self._index.pop(namespace, {}))
781
+ return max(cursor.rowcount, removed)
782
+
783
+ async def count(self, namespace: str) -> int:
784
+ """Count vectors in a namespace."""
785
+ return len(self._index.get(namespace, {}))
786
+
787
+ async def shutdown(self) -> None:
788
+ if self._db:
789
+ await self._db.close()
790
+ self._db = None
791
+ self._index.clear()
792
+
793
+
794
+ _vector_index: VectorIndex | None = None
795
+
796
+
797
+ # ============================================================
798
+ # Provider Factory
799
+ # ============================================================
800
+
801
+
802
def _resolve_model_dir(provider_key: str) -> str:
    """Return the model directory for ``provider_key``.

    Unknown keys map to "data/models/bge-m3". A relative directory is joined
    onto CLOTO_PROJECT_DIR when that variable is set; otherwise the directory
    is returned unchanged.
    """
    relative = _MODEL_DIRS.get(provider_key, "data/models/bge-m3")
    anchor_to_project = bool(_project_dir) and not os.path.isabs(relative)
    return os.path.join(_project_dir, relative) if anchor_to_project else relative
808
+
809
+
810
def create_provider() -> EmbeddingProvider:
    """Instantiate the provider selected by EMBEDDING_PROVIDER.

    The returned provider is not yet initialized.

    Raises:
        ValueError: If EMBEDDING_PROVIDER names an unsupported provider.
    """
    kind = EMBEDDING_PROVIDER

    if kind == "api_openai":
        return OpenAIEmbeddingProvider(
            api_key=EMBEDDING_API_KEY,
            api_url=EMBEDDING_API_URL,
            model=EMBEDDING_MODEL,
            timeout=EMBEDDING_TIMEOUT,
        )

    # The plain ONNX providers differ only in class; all take ONNX_MODEL_DIR.
    if kind == "onnx_miniml":
        return OnnxMiniLMProvider(model_dir=ONNX_MODEL_DIR)
    if kind == "onnx_jina_v5_nano":
        return OnnxJinaV5NanoProvider(model_dir=ONNX_MODEL_DIR)
    if kind == "onnx_bge_m3":
        return OnnxBgeM3Provider(model_dir=ONNX_MODEL_DIR)

    if kind == "mlx_bge_m3":
        # Explicit MLX_MODEL_DIR wins; otherwise use the local model
        # directory when it exists, else the HuggingFace repo ID.
        configured = os.environ.get("MLX_MODEL_DIR", "")
        if configured:
            return MlxBgeM3Provider(model_path=configured)
        local_dir = _resolve_model_dir("mlx_bge_m3")
        chosen = local_dir if os.path.isdir(local_dir) else MLX_BGE_M3_REPO
        return MlxBgeM3Provider(model_path=chosen)

    if kind == "auto_bge_m3":
        if _is_apple_silicon() and _mlx_available():
            mlx_path = os.environ.get("MLX_MODEL_DIR", "") or MLX_BGE_M3_REPO
            logger.info("auto_bge_m3: Apple Silicon detected, using MLX provider")
            return MlxBgeM3Provider(model_path=mlx_path)
        logger.info("auto_bge_m3: falling back to ONNX CPU provider")
        return OnnxBgeM3Provider(model_dir=ONNX_MODEL_DIR)

    raise ValueError(
        f"Unknown embedding provider: {EMBEDDING_PROVIDER}. "
        f"Supported: api_openai, onnx_miniml, onnx_jina_v5_nano, onnx_bge_m3, "
        f"mlx_bge_m3, auto_bge_m3"
    )
844
+
845
+
846
+ # ============================================================
847
+ # HTTP Endpoint (for CPersona inter-server communication)
848
+ # ============================================================
849
+
850
# Active embedding provider shared by the HTTP handlers and MCP tools below.
# Stays None until main() / _run_streamable_http() assigns it.
_provider: EmbeddingProvider | None = None
851
+
852
+
853
async def handle_embed(request: web.Request) -> web.Response:
    """POST /embed — Generate embeddings for input texts.

    Expects a JSON object ``{"texts": [str, ...]}`` holding 1-100 strings and
    answers with ``{"embeddings": ..., "dimensions": ...}``.

    Status codes:
        503 — provider not initialized.
        400 — body is not valid JSON, is not a JSON object, or ``texts`` is
              missing, empty, too large, or contains non-strings.
        500 — the provider raised while embedding.
    """
    if _provider is None:
        return web.json_response({"error": "Provider not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)

    # Valid JSON is not necessarily an object: a list, string or number has
    # no .get(), which would otherwise surface as an unhandled 500.
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    texts = body.get("texts")
    if not isinstance(texts, list) or not texts:
        return web.json_response(
            {"error": "'texts' must be a non-empty array of strings"},
            status=400,
        )

    # Limit batch size to prevent OOM
    if len(texts) > 100:
        return web.json_response({"error": "Batch size exceeds limit (max 100)"}, status=400)

    # Enforce the "array of strings" contract here, so bad entries are a
    # client error rather than a provider failure.
    if not all(isinstance(text, str) for text in texts):
        return web.json_response(
            {"error": "'texts' must be a non-empty array of strings"},
            status=400,
        )

    try:
        embeddings = await _provider.embed(texts)
        return web.json_response(
            {
                "embeddings": embeddings,
                "dimensions": _provider.dimensions(),
            }
        )
    except Exception as e:
        logger.exception("Embedding failed")
        return web.json_response({"error": f"Embedding failed: {e}"}, status=500)
885
+
886
+
887
async def handle_index(request: web.Request) -> web.Response:
    """POST /index — Index vectors for later search.

    Expects a JSON object ``{"namespace": ..., "items": [{"id": ..., "text": ...}]}``
    with 1-100 items; ``namespace`` defaults to ``"default"``. Answers with
    ``{"ok": true, "indexed": ...}``.

    Status codes:
        503 — provider or vector index not initialized.
        400 — body is not valid JSON, is not a JSON object, or ``items`` is
              missing, empty, too large, or has entries lacking id/text.
        500 — the index operation raised.
    """
    if _provider is None or _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)

    # A JSON array/string/number parses fine but has no .get(); reject it as
    # a client error instead of letting AttributeError become a 500.
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace", "default")
    items = body.get("items")
    if not isinstance(items, list) or not items:
        return web.json_response({"error": "'items' must be a non-empty array"}, status=400)
    if len(items) > 100:
        return web.json_response({"error": "Batch size exceeds limit (max 100)"}, status=400)

    for item in items:
        if not isinstance(item, dict) or "id" not in item or "text" not in item:
            return web.json_response({"error": "Each item must have 'id' and 'text'"}, status=400)

    try:
        indexed = await _vector_index.index(namespace, items, _provider)
        return web.json_response({"ok": True, "indexed": indexed})
    except Exception as e:
        logger.exception("Index failed")
        return web.json_response({"error": f"Index failed: {e}"}, status=500)
914
+
915
+
916
async def handle_search(request: web.Request) -> web.Response:
    """POST /search — Search indexed vectors by similarity.

    Expects a JSON object with a non-empty string ``query`` and optional
    ``namespace`` (default ``"default"``), ``limit`` (default 10, capped to
    the range 0-500) and ``min_similarity`` (default 0.3). Answers with
    ``{"results": ...}``.

    Status codes:
        503 — provider or vector index not initialized.
        400 — body is not valid JSON, is not a JSON object, ``query`` is not
              a non-empty string, or ``limit`` / ``min_similarity`` are not
              numeric.
        500 — the search operation raised.
    """
    if _provider is None or _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)

    # Non-object JSON (list, string, number) has no .get(); answer 400 rather
    # than crashing with AttributeError.
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace", "default")
    query = body.get("query")
    if not isinstance(query, str) or not query.strip():
        return web.json_response({"error": "'query' must be a non-empty string"}, status=400)

    # Client-supplied values: null, a non-numeric string, or Infinity make
    # int()/float() raise, which is a bad request, not a server error.
    try:
        limit = int(body.get("limit", 10))
        min_similarity = float(body.get("min_similarity", 0.3))
    except (TypeError, ValueError, OverflowError):
        return web.json_response(
            {"error": "'limit' must be an integer and 'min_similarity' a number"},
            status=400,
        )
    # Clamp both ends: a negative limit must not slip past the 500 cap.
    limit = max(0, min(limit, 500))

    try:
        results = await _vector_index.search(namespace, query, limit, min_similarity, _provider)
        return web.json_response({"results": results})
    except Exception as e:
        logger.exception("Search failed")
        return web.json_response({"error": f"Search failed: {e}"}, status=500)
940
+
941
+
942
async def handle_remove(request: web.Request) -> web.Response:
    """POST /remove — Remove vectors from index.

    Expects a JSON object ``{"namespace": ..., "ids": [...]}`` with a
    non-empty ``ids`` array; ``namespace`` defaults to ``"default"``.
    Answers with ``{"ok": true, "removed": ...}``.

    Status codes:
        503 — vector index not initialized.
        400 — body is not valid JSON, is not a JSON object, or ``ids`` is
              not a non-empty array.
        500 — the remove operation raised.
    """
    if _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)

    # Reject non-object JSON up front; .get() on a list/str/number would
    # raise AttributeError and turn a client mistake into a 500.
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace", "default")
    ids = body.get("ids")
    if not isinstance(ids, list) or not ids:
        return web.json_response({"error": "'ids' must be a non-empty array"}, status=400)

    try:
        removed = await _vector_index.remove(namespace, ids)
        return web.json_response({"ok": True, "removed": removed})
    except Exception as e:
        logger.exception("Remove failed")
        return web.json_response({"error": f"Remove failed: {e}"}, status=500)
963
+
964
+
965
async def handle_purge(request: web.Request) -> web.Response:
    """POST /purge — Remove all vectors in a namespace.

    Expects a JSON object ``{"namespace": "<non-empty string>"}``. There is
    deliberately no default namespace for this destructive operation.
    Answers with ``{"ok": true, "removed": ...}``.

    Status codes:
        503 — vector index not initialized.
        400 — body is not valid JSON, is not a JSON object, or ``namespace``
              is not a non-empty string.
        500 — the purge operation raised.
    """
    if _vector_index is None:
        return web.json_response({"error": "Not initialized"}, status=503)

    try:
        body = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON body"}, status=400)

    # Non-object JSON has no .get(); report it as a bad request instead of
    # letting AttributeError escape as a 500.
    if not isinstance(body, dict):
        return web.json_response({"error": "JSON body must be an object"}, status=400)

    namespace = body.get("namespace")
    if not isinstance(namespace, str) or not namespace.strip():
        return web.json_response({"error": "'namespace' must be a non-empty string"}, status=400)

    try:
        removed = await _vector_index.purge_namespace(namespace)
        return web.json_response({"ok": True, "removed": removed})
    except Exception as e:
        logger.exception("Purge failed")
        return web.json_response({"error": f"Purge failed: {e}"}, status=500)
985
+
986
+
987
async def run_http_server(port: int) -> None:
    """Run the HTTP embedding endpoint alongside MCP stdio.

    Serves ``POST /embed`` on ``127.0.0.1:<port>``; the index routes
    (``/index``, ``/search``, ``/remove``, ``/purge``) are added only when
    the index is enabled and initialized. Runs until the surrounding task is
    cancelled.

    Args:
        port: TCP port to bind on the loopback interface.

    Raises:
        OSError: if the port cannot be bound. The runner is cleaned up
            before the error propagates.
    """
    app = web.Application()
    app.router.add_post("/embed", handle_embed)
    if EMBEDDING_INDEX_ENABLED and _vector_index is not None:
        app.router.add_post("/index", handle_index)
        app.router.add_post("/search", handle_search)
        app.router.add_post("/remove", handle_remove)
        app.router.add_post("/purge", handle_purge)

    runner = web.AppRunner(app, access_log=None)
    await runner.setup()

    # Everything after setup() sits inside the try so that a failed bind
    # (e.g. port already in use) still releases the runner.
    try:
        site = web.TCPSite(runner, "127.0.0.1", port)
        await site.start()
        logger.info("HTTP embedding endpoint started on http://127.0.0.1:%d/embed", port)

        # Block until cancelled
        await asyncio.Event().wait()
    finally:
        await runner.cleanup()
1008
+
1009
+
1010
+ # ============================================================
1011
+ # MCP Server
1012
+ # ============================================================
1013
+
1014
# Tool registry for this MCP server: the @registry.tool decorators below
# attach the tools to it, and registry.server is what the transports run.
registry = ToolRegistry("cloto-mcp-embedding")
1015
+
1016
+
1017
@registry.tool(
    "embed",
    "Generate vector embeddings for input texts.",
    {
        "type": "object",
        "properties": {
            "texts": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Texts to embed (batch, max 100)",
            }
        },
        "required": ["texts"],
    },
)
async def handle_embed_tool(arguments: dict) -> dict:
    """MCP ``embed`` tool: embed a batch of 1-100 texts.

    Returns ``{"embeddings": ..., "dimensions": ...}`` on success and
    ``{"error": <message>}`` on any validation or provider failure.
    """
    if _provider is None:
        return {"error": "Provider not initialized"}

    batch = arguments.get("texts", [])
    if not (isinstance(batch, list) and batch):
        return {"error": "'texts' must be a non-empty array"}
    if len(batch) > 100:
        return {"error": "Batch size exceeds limit (max 100)"}

    try:
        vectors = await _provider.embed(batch)
        dims = _provider.dimensions()
    except Exception as exc:
        return {"error": str(exc)}
    return {"embeddings": vectors, "dimensions": dims}
1051
+
1052
+
1053
@registry.tool(
    "index",
    "Index text items for vector similarity search. Each item gets embedded and stored persistently.",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace for isolation (e.g., 'cpersona:agent-id')",
                "default": "default",
            },
            "items": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "id": {"type": "string"},
                        "text": {"type": "string"},
                    },
                    "required": ["id", "text"],
                },
                "description": "Items to index (max 100)",
            },
        },
        "required": ["items"],
    },
)
async def handle_index_tool(arguments: dict) -> dict:
    """MCP ``index`` tool: embed and store 1-100 ``{"id", "text"}`` items.

    Applies the same item validation as the HTTP ``/index`` handler, so a
    malformed call gets a clear message instead of an index-layer failure.

    Returns ``{"ok": True, "indexed": ...}`` on success, otherwise
    ``{"error": <message>}``.
    """
    if _provider is None or _vector_index is None:
        return {"error": "Not initialized or index disabled"}

    namespace = arguments.get("namespace", "default")
    items = arguments.get("items", [])
    # The isinstance check keeps a string or dict from passing the size test
    # and reaching the index as if it were a batch.
    if not isinstance(items, list) or not items or len(items) > 100:
        return {"error": "items must be 1-100 entries"}

    for item in items:
        if not isinstance(item, dict) or "id" not in item or "text" not in item:
            return {"error": "Each item must have 'id' and 'text'"}

    try:
        indexed = await _vector_index.index(namespace, items, _provider)
        return {"ok": True, "indexed": indexed}
    except Exception as e:
        # Log the traceback as the HTTP handler does; the caller only sees
        # the message.
        logger.exception("Index failed")
        return {"error": str(e)}
1094
+
1095
+
1096
@registry.tool(
    "search",
    "Search indexed vectors by semantic similarity. Returns top-K results with scores.",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace to search within",
                "default": "default",
            },
            "query": {
                "type": "string",
                "description": "Search query text",
            },
            "limit": {
                "type": "integer",
                "description": "Max results to return (default: 10)",
                "default": 10,
            },
            "min_similarity": {
                "type": "number",
                "description": "Minimum cosine similarity threshold (default: 0.3)",
                "default": 0.3,
            },
        },
        "required": ["query"],
    },
)
async def handle_search_tool(arguments: dict) -> dict:
    """MCP ``search`` tool: similarity search within a namespace.

    ``limit`` is clamped to the range 0-500. Malformed arguments produce an
    ``{"error": ...}`` result rather than an exception escaping the tool.

    Returns ``{"results": ...}`` on success, otherwise ``{"error": <message>}``.
    """
    if _provider is None or _vector_index is None:
        return {"error": "Not initialized or index disabled"}

    namespace = arguments.get("namespace", "default")
    query = arguments.get("query", "")
    # Check the type first: .strip() on a non-string would raise.
    if not isinstance(query, str) or not query.strip():
        return {"error": "query must be non-empty"}

    # Caller-supplied values; None, a non-numeric string, or infinity make
    # int()/float() raise.
    try:
        limit = int(arguments.get("limit", 10))
        min_similarity = float(arguments.get("min_similarity", 0.3))
    except (TypeError, ValueError, OverflowError):
        return {"error": "limit must be an integer and min_similarity a number"}
    # Clamp both ends so a negative limit cannot bypass the 500 cap.
    limit = max(0, min(limit, 500))

    try:
        results = await _vector_index.search(namespace, query, limit, min_similarity, _provider)
        return {"results": results}
    except Exception as e:
        # Log the traceback as the HTTP handler does; the caller only sees
        # the message.
        logger.exception("Search failed")
        return {"error": str(e)}
1142
+
1143
+
1144
@registry.tool(
    "remove",
    "Remove items from the vector index by ID.",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace containing the items",
                "default": "default",
            },
            "ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Item IDs to remove",
            },
        },
        "required": ["ids"],
    },
)
async def handle_remove_tool(arguments: dict) -> dict:
    """MCP ``remove`` tool: delete the given IDs from a namespace.

    Returns ``{"ok": True, "removed": ...}`` on success, otherwise
    ``{"error": <message>}``.
    """
    if _vector_index is None:
        return {"error": "Index not initialized or disabled"}

    target_namespace = arguments.get("namespace", "default")
    target_ids = arguments.get("ids", [])
    if not target_ids:
        return {"error": "ids must be non-empty"}

    try:
        removed_count = await _vector_index.remove(target_namespace, target_ids)
    except Exception as exc:
        return {"error": str(exc)}
    return {"ok": True, "removed": removed_count}
1178
+
1179
+
1180
@registry.tool(
    "purge",
    "Remove ALL vectors in a namespace. Use for bulk cleanup (e.g., agent deletion).",
    {
        "type": "object",
        "properties": {
            "namespace": {
                "type": "string",
                "description": "Namespace to purge completely",
            },
        },
        "required": ["namespace"],
    },
)
async def handle_purge_tool(arguments: dict) -> dict:
    """MCP ``purge`` tool: drop every vector stored under one namespace.

    The namespace has no default and must be supplied. Returns
    ``{"ok": True, "removed": ...}`` on success, otherwise
    ``{"error": <message>}``.
    """
    if _vector_index is None:
        return {"error": "Index not initialized or disabled"}

    target_namespace = arguments.get("namespace", "")
    if not target_namespace:
        return {"error": "namespace is required"}

    try:
        removed_count = await _vector_index.purge_namespace(target_namespace)
    except Exception as exc:
        return {"error": str(exc)}
    return {"ok": True, "removed": removed_count}
1207
+
1208
+
1209
+ # ============================================================
1210
+ # Main
1211
+ # ============================================================
1212
+
1213
+
1214
async def _run_streamable_http() -> None:
    """Run embedding as a Streamable HTTP MCP server (no stdio, no REST /embed).

    Enabled by setting EMBEDDING_TRANSPORT=streamable-http. Listens on
    EMBEDDING_MCP_HTTP_HOST (default 0.0.0.0) and EMBEDDING_MCP_HTTP_PORT
    (default 8403). The MCP endpoint is mounted at /embedding/mcp,
    /embedding, /mcp and /, so it can coexist with other services behind
    path-based reverse proxies.

    Raises:
        ValueError: if EMBEDDING_MCP_HTTP_PORT is not an integer in 1-65535.
    """
    global _provider, _vector_index

    import contextlib
    from collections.abc import AsyncIterator

    import uvicorn
    from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
    from starlette.applications import Starlette
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware
    from starlette.routing import Mount

    # Resolve and validate the listen address before the expensive provider
    # load, so a misconfigured port fails fast and leaks nothing.
    # NOTE(review): the default host binds every interface and this function
    # adds no authentication layer. Confirm deployments front it with a
    # proxy or firewall.
    host = os.environ.get("EMBEDDING_MCP_HTTP_HOST", "0.0.0.0")
    port = int(os.environ.get("EMBEDDING_MCP_HTTP_PORT", "8403"))
    if not (1 <= port <= 65535):
        raise ValueError(f"EMBEDDING_MCP_HTTP_PORT must be 1-65535, got {port}")

    _provider = create_provider()
    await _provider.initialize()

    # From here on the provider holds resources, so every failure path,
    # including index initialization, must reach the shutdown in finally.
    try:
        if EMBEDDING_INDEX_ENABLED:
            index = VectorIndex(EMBEDDING_INDEX_DB_PATH)
            await index.initialize()
            # Publish only a fully initialized index, so shutdown below is
            # never called on a half-built one.
            _vector_index = index

        session_manager = StreamableHTTPSessionManager(
            app=registry.server,
            stateless=True,
        )

        async def mcp_endpoint(scope, receive, send):
            await session_manager.handle_request(scope, receive, send)

        @contextlib.asynccontextmanager
        async def lifespan(_app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                logger.info("Embedding Streamable HTTP server ready")
                yield

        app = Starlette(
            routes=[
                Mount("/embedding/mcp", app=mcp_endpoint),
                Mount("/embedding", app=mcp_endpoint),
                Mount("/mcp", app=mcp_endpoint),
                Mount("/", app=mcp_endpoint),
            ],
            middleware=[
                Middleware(
                    CORSMiddleware,
                    allow_origins=["https://claude.ai", "https://www.claude.ai"],
                    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],
                    allow_headers=[
                        "Authorization",
                        "Content-Type",
                        "Mcp-Session-Id",
                        "Mcp-Protocol-Version",
                        "Last-Event-Id",
                    ],
                    expose_headers=["Mcp-Session-Id"],
                ),
            ],
            lifespan=lifespan,
        )

        logger.info("Starting Embedding Streamable HTTP MCP on %s:%d", host, port)

        config = uvicorn.Config(app, host=host, port=port, log_level="info")
        server = uvicorn.Server(config)
        await server.serve()
    finally:
        # Nested so that a failing index shutdown cannot skip the provider's.
        try:
            if _vector_index is not None:
                await _vector_index.shutdown()
        finally:
            await _provider.shutdown()
        logger.info("Embedding Streamable HTTP server shut down")
1293
+
1294
+
1295
async def main():
    """Start the embedding server on the configured transport.

    With EMBEDDING_TRANSPORT=streamable-http, control is handed to
    ``_run_streamable_http()``. Otherwise (default ``stdio``) the provider
    and optional vector index are initialized, the local HTTP endpoint is
    started as a background task, and the MCP server runs over stdio until
    the streams close. The HTTP task, index and provider are shut down on
    exit, even if one of the shutdown steps fails.
    """
    global _provider, _vector_index

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    transport = os.environ.get("EMBEDDING_TRANSPORT", "stdio")
    if transport == "streamable-http":
        await _run_streamable_http()
        return

    logger.info(
        "Starting embedding server (provider=%s, http_port=%d, index=%s)",
        EMBEDDING_PROVIDER,
        EMBEDDING_HTTP_PORT,
        "enabled" if EMBEDDING_INDEX_ENABLED else "disabled",
    )

    def _report_http_failure(task: asyncio.Task) -> None:
        # Report an HTTP endpoint failure (e.g. port already in use) when it
        # happens, instead of only when the task is awaited at shutdown.
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error("HTTP embedding endpoint stopped unexpectedly", exc_info=error)

    _provider = create_provider()
    await _provider.initialize()

    http_task = None
    # From here on the provider holds resources, so every failure path,
    # including index initialization, must reach the shutdown in finally.
    try:
        # Initialize vector index if enabled
        if EMBEDDING_INDEX_ENABLED:
            index = VectorIndex(EMBEDDING_INDEX_DB_PATH)
            await index.initialize()
            # Publish only a fully initialized index.
            _vector_index = index

        # Start HTTP endpoint as background task
        http_task = asyncio.create_task(run_http_server(EMBEDDING_HTTP_PORT))
        http_task.add_done_callback(_report_http_failure)

        async with stdio_server() as (read_stream, write_stream):
            await registry.server.run(
                read_stream,
                write_stream,
                registry.server.create_initialization_options(),
            )
    finally:
        if http_task is not None:
            http_task.cancel()
            try:
                await http_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # The task had already failed; the done callback logged it.
                # Re-raising here would skip the index/provider shutdown.
                pass
        # Nested so that a failing index shutdown cannot skip the provider's.
        try:
            if _vector_index is not None:
                await _vector_index.shutdown()
        finally:
            await _provider.shutdown()
        logger.info("Embedding server shut down")
1343
+
1344
+
1345
def run():
    """Synchronous entry point for the console script and ``python -m cembedding``.

    Creates the ``main()`` coroutine and runs it to completion on a fresh
    event loop.
    """
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
1348
+
1349
+
1350
# Script entry: start the server when this module is executed directly.
if __name__ == "__main__":
    run()