grepmax 0.2.2 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
[project]
name = "mlx-embed-server"
version = "0.1.0"
description = "MLX-accelerated embedding server for grepmax"
# 3.13 floor matches the resource_tracker workaround in server.py.
requires-python = ">=3.13"
dependencies = [
    "fastapi>=0.115.0",
    "uvicorn>=0.34.0",
    # No pinned PyPI release used here; installed straight from GitHub HEAD.
    "mlx-embeddings @ git+https://github.com/Blaizzy/mlx-embeddings.git",
]

[project.scripts]
# Console entry point: `mlx-embed-server` runs server.py's main().
mlx-embed-server = "server:main"
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""MLX-accelerated embedding server for grepmax.
|
|
2
|
+
|
|
3
|
+
Serves granite-embedding-small-english-r2 on Apple Silicon GPU via MLX.
|
|
4
|
+
gmax workers call POST /embed with {"texts": [...]} and get back {"vectors": [...]}.
|
|
5
|
+
Falls through to ONNX CPU if this server isn't running.
|
|
6
|
+
|
|
7
|
+
IMPORTANT: All MLX operations must run on a single thread. FastAPI async
|
|
8
|
+
endpoints run on the event loop thread, avoiding the Metal thread-safety
|
|
9
|
+
crashes that occur when uvicorn's sync threadpool dispatches concurrent
|
|
10
|
+
GPU operations.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import asyncio
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import signal
|
|
17
|
+
import socket
|
|
18
|
+
import time
|
|
19
|
+
import warnings
|
|
20
|
+
from contextlib import asynccontextmanager
|
|
21
|
+
|
|
22
|
+
# Suppress all HF/transformers/tqdm noise before any imports touch them.
# Order matters: these env vars must be set before transformers /
# huggingface_hub are imported below, or they have no effect.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
os.environ["HF_HUB_VERBOSITY"] = "error"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Avoid tokenizer fork-parallelism warnings/deadlocks under a server process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message=".*PyTorch.*")
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
import mlx.core as mx
|
|
36
|
+
import uvicorn
|
|
37
|
+
from fastapi import FastAPI
|
|
38
|
+
from mlx_embeddings import load
|
|
39
|
+
from pydantic import BaseModel
|
|
40
|
+
from transformers import AutoTokenizer
|
|
41
|
+
|
|
42
|
+
# Configuration, each overridable via environment variable.
MODEL_ID = os.environ.get(
    "MLX_EMBED_MODEL", "ibm-granite/granite-embedding-small-english-r2"
)
# Localhost TCP port the HTTP server listens on.
PORT = int(os.environ.get("MLX_EMBED_PORT", "8100"))
# Maximum number of texts embedded in one forward pass.
MAX_BATCH = int(os.environ.get("MLX_EMBED_MAX_BATCH", "64"))
# NOTE(review): IDLE_TIMEOUT_S is defined and last_activity is tracked, but no
# idle-shutdown task is visible in this file — confirm whether auto-shutdown
# after idle was intended and is implemented elsewhere.
IDLE_TIMEOUT_S = int(os.environ.get("MLX_EMBED_IDLE_TIMEOUT", "1800"))  # 30 min
|
|
48
|
+
|
|
49
|
+
# Populated once by load_model() during FastAPI startup (lifespan).
model = None
tokenizer = None
# Wall-clock time of the most recent request; refreshed by /embed and /health.
last_activity = time.time()

# Serialize all MLX GPU operations — Metal is not thread-safe
_mlx_lock = asyncio.Lock()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def is_port_in_use(port: int) -> bool:
    """Return True when something on localhost already accepts connections on *port*."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns 0 on a successful connect, an errno otherwise.
        return probe.connect_ex(("127.0.0.1", port)) == 0
    finally:
        probe.close()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def embed_texts(texts: list[str]) -> mx.array:
    """Tokenize, forward pass, L2 normalize.

    mlx_embeddings model already does mean pooling internally —
    last_hidden_state is (batch, dim), not (batch, seq, dim).
    """
    batch = tokenizer(
        texts, padding=True, truncation=True, max_length=256, return_tensors="np"
    )
    ids = mx.array(batch["input_ids"])
    mask = mx.array(batch["attention_mask"])

    outputs = model(input_ids=ids, attention_mask=mask)

    # text_embeds is the pooled output; fall back to last_hidden_state.
    pooled = getattr(outputs, "text_embeds", None)
    if pooled is None:
        pooled = outputs.last_hidden_state

    # L2 normalize, clamping the norm away from zero to avoid division by zero.
    magnitude = mx.maximum(
        mx.sqrt(mx.sum(pooled * pooled, axis=-1, keepdims=True)), 1e-12
    )
    unit = pooled / magnitude
    mx.eval(unit)  # force lazy GPU computation before handing vectors back
    return unit
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def load_model():
    """Populate module-level model/tokenizer and run one warm-up pass."""
    global model, tokenizer
    print(f"[mlx-embed] Loading {MODEL_ID}...")
    # Tokenizer and weights load independently; order is irrelevant.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model, _ = load(MODEL_ID)
    # Warm-up forward pass so the first real request pays no compile cost.
    embed_texts(["warm up"])
    print("[mlx-embed] Model ready on Metal GPU.")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load the model before serving; nothing to tear down."""
    load_model()
    yield
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# lifespan loads the model once before uvicorn starts accepting requests.
app = FastAPI(lifespan=lifespan)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""

    # Texts to embed, one vector returned per text.
    texts: list[str]
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class EmbedResponse(BaseModel):
    """Response body for POST /embed."""

    # L2-normalized embedding vectors, in input order.
    vectors: list[list[float]]
    # Embedding dimensionality (0 when no vectors were produced).
    dim: int
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@app.post("/embed")
async def embed(request: EmbedRequest) -> EmbedResponse:
    """Embed request.texts and return one L2-normalized vector per text.

    Fix: the previous version truncated the input to MAX_BATCH texts
    (`request.texts[:MAX_BATCH]`), silently returning fewer vectors than
    texts supplied. We now process the input in MAX_BATCH-sized chunks so
    every text gets a vector while each forward pass (and its GPU memory)
    stays bounded. An empty request now returns an empty response without
    invoking the model at all.
    """
    global last_activity
    last_activity = time.time()

    vectors_list: list[list[float]] = []
    # Hold the lock across all chunks — Metal is not thread-safe, and this
    # keeps a single request's work contiguous on the GPU.
    async with _mlx_lock:
        for start in range(0, len(request.texts), MAX_BATCH):
            chunk = request.texts[start : start + MAX_BATCH]
            vectors_list.extend(embed_texts(chunk).tolist())

    return EmbedResponse(
        vectors=vectors_list,
        dim=len(vectors_list[0]) if vectors_list else 0,
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@app.get("/health")
async def health():
    """Liveness probe; hitting it also refreshes the idle-activity clock."""
    global last_activity
    last_activity = time.time()
    return {"status": "ok", "model": MODEL_ID}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def main():
    """Entry point: start the uvicorn server unless one is already running."""
    # Bail early if port is already taken
    if is_port_in_use(PORT):
        print(f"[mlx-embed] Port {PORT} already in use — server is already running.")
        return

    print(f"[mlx-embed] Starting on port {PORT}")

    # Clean shutdown — exit immediately, skip uvicorn's noisy teardown
    def handle_signal(sig, frame):
        print("[mlx-embed] Stopped.")
        # Kill the resource_tracker child process before exit to prevent
        # its spurious "leaked semaphore" warning (Python 3.13 bug)
        # NOTE(review): _resource_tracker._pid is a private CPython attribute —
        # confirm it still exists on newer Python versions.
        try:
            from multiprocessing.resource_tracker import _resource_tracker
            if _resource_tracker._pid is not None:
                os.kill(_resource_tracker._pid, signal.SIGKILL)
        except Exception:
            pass
        # os._exit skips atexit/finally handlers: instant, silent exit.
        os._exit(0)

    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    # Runs until a signal arrives; handle_signal never returns control here.
    uvicorn.run(app, host="127.0.0.1", port=PORT, log_level="warning")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Allow `python server.py` in addition to the mlx-embed-server console script.
if __name__ == "__main__":
    main()
|