grepmax 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/config.js CHANGED
@@ -142,7 +142,6 @@ exports.INDEXABLE_EXTENSIONS = new Set([
142
142
  ".xml",
143
143
  ".md",
144
144
  ".mdx",
145
- ".txt",
146
145
  ".gitignore",
147
146
  ".dockerfile",
148
147
  "dockerfile",
@@ -22,6 +22,22 @@ exports.DEFAULT_IGNORE_PATTERNS = [
22
22
  "**/__pycache__/**",
23
23
  "**/coverage/**",
24
24
  "**/venv/**",
25
+ "**/.venv/**",
26
+ "**/.tox/**",
27
+ "**/.mypy_cache/**",
28
+ "**/.pytest_cache/**",
29
+ "**/.next/**",
30
+ "**/.nuxt/**",
31
+ "**/.gradle/**",
32
+ "**/.m2/**",
33
+ "**/vendor/**",
34
+ "**/.osgrep/**",
35
+ "**/.gmax/**",
36
+ // Minified/generated assets
37
+ "*.min.js",
38
+ "*.min.css",
39
+ "*.map",
40
+ "*.wasm",
25
41
  // Test fixtures and benchmark data
26
42
  "**/fixtures/**",
27
43
  "**/benchmark/**",
@@ -0,0 +1,13 @@
1
+ [project]
2
+ name = "mlx-embed-server"
3
+ version = "0.1.0"
4
+ description = "MLX-accelerated embedding server for grepmax"
5
+ requires-python = ">=3.13"
6
+ dependencies = [
7
+ "fastapi>=0.115.0",
8
+ "uvicorn>=0.34.0",
9
+ "mlx-embeddings @ git+https://github.com/Blaizzy/mlx-embeddings.git",
10
+ ]
11
+
12
+ [project.scripts]
13
+ mlx-embed-server = "server:main"
@@ -0,0 +1,169 @@
1
+ """MLX-accelerated embedding server for grepmax.
2
+
3
+ Serves granite-embedding-small-english-r2 on Apple Silicon GPU via MLX.
4
+ gmax workers call POST /embed with {"texts": [...]} and get back {"vectors": [...]}.
5
+ Falls through to ONNX CPU if this server isn't running.
6
+
7
+ IMPORTANT: All MLX operations must run on a single thread. FastAPI async
8
+ endpoints run on the event loop thread, avoiding the Metal thread-safety
9
+ crashes that occur when uvicorn's sync threadpool dispatches concurrent
10
+ GPU operations.
11
+ """
12
+
13
+ import asyncio
14
+ import logging
15
+ import os
16
+ import signal
17
+ import socket
18
+ import time
19
+ import warnings
20
+ from contextlib import asynccontextmanager
21
+
22
# Suppress all HF/transformers/tqdm noise before any imports touch them
# (these env vars are only honored if set BEFORE huggingface_hub /
# transformers are first imported — which is why this block sits above
# the third-party imports instead of with them).
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
os.environ["HF_HUB_VERBOSITY"] = "error"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# The Rust tokenizers library warns on fork-after-parallel-use; this is a
# single-process server, so parallelism is disabled outright.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message=".*PyTorch.*")
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
31
+
32
+
33
+
34
+
35
+ import mlx.core as mx
36
+ import uvicorn
37
+ from fastapi import FastAPI
38
+ from mlx_embeddings import load
39
+ from pydantic import BaseModel
40
+ from transformers import AutoTokenizer
41
+
42
# Embedding model to serve; overridable via env for experimentation.
MODEL_ID = os.environ.get(
    "MLX_EMBED_MODEL", "ibm-granite/granite-embedding-small-english-r2"
)
# Loopback TCP port the server binds (see main()).
PORT = int(os.environ.get("MLX_EMBED_PORT", "8100"))
# Upper bound on texts per GPU forward pass.
MAX_BATCH = int(os.environ.get("MLX_EMBED_MAX_BATCH", "64"))
# NOTE(review): IDLE_TIMEOUT_S is defined but never read anywhere in this
# file — the idle-shutdown behavior it implies appears unimplemented here;
# confirm whether another component consumes it.
IDLE_TIMEOUT_S = int(os.environ.get("MLX_EMBED_IDLE_TIMEOUT", "1800"))  # 30 min

# Populated by load_model() during the FastAPI lifespan; None until then.
model = None
tokenizer = None
# Wall-clock timestamp of the most recent /embed or /health request.
last_activity = time.time()

# Serialize all MLX GPU operations — Metal is not thread-safe
_mlx_lock = asyncio.Lock()
55
+
56
+
57
def is_port_in_use(port: int) -> bool:
    """Return True if something is already listening on 127.0.0.1:*port*.

    Used as a cheap singleton check before starting the server. The
    original probe used the OS-default connect timeout, which can stall
    for a long time on a filtered port; a 1-second cap keeps startup
    snappy, and any socket error is treated as "not in use".
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1.0)  # loopback connects resolve near-instantly
        try:
            return s.connect_ex(("127.0.0.1", port)) == 0
        except OSError:
            # Treat timeouts/odd socket failures as "port free" — same
            # outcome as the non-zero errno path.
            return False
60
+
61
+
62
def embed_texts(texts: list[str]) -> mx.array:
    """Tokenize *texts*, run the model, and L2-normalize the pooled output.

    mlx_embeddings models pool internally, so the result is already
    (batch, dim) rather than (batch, seq, dim).
    """
    batch = tokenizer(
        texts, padding=True, truncation=True, max_length=256, return_tensors="np"
    )
    ids = mx.array(batch["input_ids"])
    mask = mx.array(batch["attention_mask"])

    out = model(input_ids=ids, attention_mask=mask)

    # Prefer the explicitly pooled embeddings when the model exposes them;
    # otherwise last_hidden_state already carries the pooled vectors.
    pooled = getattr(out, "text_embeds", None)
    if pooled is None:
        pooled = out.last_hidden_state

    # L2 normalize, flooring the norm to avoid division by zero.
    norm = mx.maximum(
        mx.sqrt(mx.sum(pooled * pooled, axis=-1, keepdims=True)), 1e-12
    )
    unit = pooled / norm
    mx.eval(unit)  # force the lazy MLX graph to materialize before returning
    return unit
88
+
89
+
90
def load_model():
    """Populate the module-level model/tokenizer and warm the GPU kernels."""
    global model, tokenizer
    print(f"[mlx-embed] Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model, _ = load(MODEL_ID)
    # Warm-up pass: compiles the Metal kernels so the first real request
    # doesn't pay that cost.
    embed_texts(["warm up"])
    print("[mlx-embed] Model ready on Metal GPU.")
97
+
98
+
99
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model once before serving requests."""
    load_model()
    yield


app = FastAPI(lifespan=lifespan)
106
+
107
+
108
class EmbedRequest(BaseModel):
    """Body of POST /embed: the texts to encode."""

    texts: list[str]
110
+
111
+
112
class EmbedResponse(BaseModel):
    """Embedding vectors, one per input text, plus their dimensionality."""

    vectors: list[list[float]]
    dim: int
115
+
116
+
117
@app.post("/embed")
async def embed(request: EmbedRequest) -> EmbedResponse:
    """Embed every text in the request; return one vector per input text.

    Fix: the previous version did `request.texts[:MAX_BATCH]`, silently
    dropping inputs past the cap — callers zipping vectors back onto their
    texts got misaligned results. We now process MAX_BATCH-sized chunks,
    so each GPU pass stays bounded but no input is lost. An empty request
    short-circuits (the loop never runs) instead of tokenizing [].
    """
    global last_activity
    last_activity = time.time()

    vectors_list: list[list[float]] = []
    for start in range(0, len(request.texts), MAX_BATCH):
        chunk = request.texts[start:start + MAX_BATCH]
        # Metal is not thread-safe: hold the lock for each GPU pass, but
        # release it between chunks so /health stays responsive.
        async with _mlx_lock:
            vectors_list.extend(embed_texts(chunk).tolist())

    return EmbedResponse(
        vectors=vectors_list,
        dim=len(vectors_list[0]) if vectors_list else 0,
    )
132
+
133
+
134
@app.get("/health")
async def health():
    """Liveness probe used by gmax workers to detect the MLX server."""
    global last_activity
    # NOTE(review): last_activity is written here but never read in this
    # file — presumably an idle-shutdown hook elsewhere consumes it; confirm.
    last_activity = time.time()
    return {"status": "ok", "model": MODEL_ID}
139
+
140
+
141
def main():
    """Entry point: start the embedding server on loopback, or exit if one
    is already running on PORT."""
    # Bail early if port is already taken
    if is_port_in_use(PORT):
        print(f"[mlx-embed] Port {PORT} already in use — server is already running.")
        return

    print(f"[mlx-embed] Starting on port {PORT}")

    # Clean shutdown — exit immediately, skip uvicorn's noisy teardown
    def handle_signal(sig, frame):
        print("[mlx-embed] Stopped.")
        # Kill the resource_tracker child process before exit to prevent
        # its spurious "leaked semaphore" warning (Python 3.13 bug)
        # NOTE(review): _resource_tracker is a private CPython API and may
        # change between versions — re-verify on every Python upgrade.
        try:
            from multiprocessing.resource_tracker import _resource_tracker
            if _resource_tracker._pid is not None:
                os.kill(_resource_tracker._pid, signal.SIGKILL)
        except Exception:
            pass
        # os._exit deliberately skips atexit/finally handlers so uvicorn's
        # normal (noisy) shutdown path never runs.
        os._exit(0)

    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    # Loopback only — this server is a local sidecar, never exposed externally.
    uvicorn.run(app, host="127.0.0.1", port=PORT, log_level="warning")
166
+
167
+
168
+ if __name__ == "__main__":
169
+ main()