hpc-as-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hpc_as_api/__init__.py +54 -0
- hpc_as_api/app.py +553 -0
- hpc_as_api/auth.py +459 -0
- hpc_as_api/compute.py +943 -0
- hpc_as_api/crypto.py +127 -0
- hpc_as_api/utils.py +202 -0
- hpc_as_api-0.1.0.dist-info/METADATA +344 -0
- hpc_as_api-0.1.0.dist-info/RECORD +11 -0
- hpc_as_api-0.1.0.dist-info/WHEEL +4 -0
- hpc_as_api-0.1.0.dist-info/entry_points.txt +2 -0
- hpc_as_api-0.1.0.dist-info/licenses/LICENSE +153 -0
hpc_as_api/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""
|
|
2
|
+
hpc-as-api — OpenAI-compatible API gateway for HPC clusters via Globus Compute.
|
|
3
|
+
|
|
4
|
+
Quick start (programmatic use):
|
|
5
|
+
from hpc_as_api.compute import GlobusComputeClient
|
|
6
|
+
|
|
7
|
+
client = GlobusComputeClient(
|
|
8
|
+
endpoint_id="8d978809-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
|
9
|
+
models={
|
|
10
|
+
"qwen25-vl-72b": {
|
|
11
|
+
"hf_name": "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
|
|
12
|
+
"url": "http://ghi2-002:8000",
|
|
13
|
+
"context_reserve_output": 4096,
|
|
14
|
+
}
|
|
15
|
+
},
|
|
16
|
+
)
|
|
17
|
+
result = await client.submit_inference(
|
|
18
|
+
messages=[{"role": "user", "content": "Hello!"}],
|
|
19
|
+
model="qwen25-vl-72b",
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
Quick start (FastAPI service):
|
|
23
|
+
# Set env vars: GLOBUS_COMPUTE_ENDPOINT_ID, HPC_MODELS, RELAY_URL, ...
|
|
24
|
+
# Then run: uvicorn hpc_as_api.app:app --host 0.0.0.0 --port 8001
|
|
25
|
+
|
|
26
|
+
# Or embed the router in your existing FastAPI app:
|
|
27
|
+
from hpc_as_api.app import router
|
|
28
|
+
app.include_router(router, prefix="/hpc")
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from hpc_as_api.utils import (
|
|
32
|
+
count_images,
|
|
33
|
+
extract_text_content,
|
|
34
|
+
has_images,
|
|
35
|
+
strip_old_images,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
# GlobusComputeClient depends on globus_compute_sdk and globus_sdk, which are
# optional (hpc-as-api[globus]). Import it opportunistically so the base
# package still imports when those extras are not installed.
try:
    from hpc_as_api.compute import GlobusComputeClient

    _GLOBUS_AVAILABLE = True
except ImportError:
    _GLOBUS_AVAILABLE = False

__version__ = "0.1.0"

# Public API. GlobusComputeClient is advertised only when the import above
# succeeded: listing a name the module does not define makes
# ``from hpc_as_api import *`` raise AttributeError.
__all__ = [
    "extract_text_content",
    "has_images",
    "count_images",
    "strip_old_images",
]
if _GLOBUS_AVAILABLE:
    __all__.insert(0, "GlobusComputeClient")
|
hpc_as_api/app.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HPC Gateway — FastAPI application for routing LLM requests to HPC clusters.
|
|
3
|
+
|
|
4
|
+
DUAL-USE DESIGN:
|
|
5
|
+
----------------
|
|
6
|
+
This module serves two roles:
|
|
7
|
+
|
|
8
|
+
1. STANDALONE SERVICE:
|
|
9
|
+
Run as its own process with uvicorn. The caller sends OpenAI-compatible
|
|
10
|
+
/v1/chat/completions requests; the gateway dispatches them to the HPC cluster
|
|
11
|
+
via Globus Compute and streams tokens back through the WebSocket relay.
|
|
12
|
+
→ Start with: uvicorn hpc_as_api.app:app --host 0.0.0.0 --port 8001
|
|
13
|
+
|
|
14
|
+
2. EMBEDDED ROUTER:
|
|
15
|
+
Import `router` and mount it in any existing FastAPI application:
|
|
16
|
+
from hpc_as_api.app import router
|
|
17
|
+
app.include_router(router, prefix="/hpc")
|
|
18
|
+
Same routes, same logic, no separate process needed.
|
|
19
|
+
|
|
20
|
+
CONFIGURATION:
|
|
21
|
+
--------------
|
|
22
|
+
All settings come from environment variables. No config files to manage:
|
|
23
|
+
|
|
24
|
+
GLOBUS_COMPUTE_ENDPOINT_ID UUID of the HPC cluster's Globus endpoint
|
|
25
|
+
HPC_MODELS JSON dict mapping model names to their config
|
|
26
|
+
(see GlobusComputeClient for the schema)
|
|
27
|
+
RELAY_URL WebSocket URL of the relay server
|
|
28
|
+
RELAY_SECRET Shared secret for relay authentication
|
|
29
|
+
RELAY_ENCRYPTION_KEY AES-256 key (hex) for E2E relay encryption
|
|
30
|
+
|
|
31
|
+
HPC_PROXY_HOST Host to bind to (default: 0.0.0.0)
|
|
32
|
+
HPC_PROXY_PORT Port to listen on (default: 8001)
|
|
33
|
+
USE_GLOBUS_COMPUTE "true"/"false" (default: true)
|
|
34
|
+
VLLM_SERVER_URL Fallback vLLM URL when not using Globus
|
|
35
|
+
LOG_LEVEL Logging level (default: INFO)
|
|
36
|
+
|
|
37
|
+
AUTHENTICATION:
|
|
38
|
+
---------------
|
|
39
|
+
Every /v1/* endpoint requires authentication. Two modes are supported:
|
|
40
|
+
- Globus token: Bearer token from Globus Auth (checked via introspection)
|
|
41
|
+
- API key: Static API key from the HPC_API_KEYS env var
|
|
42
|
+
|
|
43
|
+
Set USE_GLOBUS_AUTH=true to enable Globus token validation.
|
|
44
|
+
Set HPC_API_KEYS to a comma-separated list of valid API keys.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
import asyncio
|
|
48
|
+
import json
|
|
49
|
+
import logging
|
|
50
|
+
import os
|
|
51
|
+
from contextlib import asynccontextmanager
|
|
52
|
+
|
|
53
|
+
import httpx
|
|
54
|
+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
|
55
|
+
from fastapi.responses import StreamingResponse
|
|
56
|
+
|
|
57
|
+
from hpc_as_api.auth import CallerIdentity, authenticate, validate_messages
|
|
58
|
+
from hpc_as_api.compute import GlobusComputeClient
|
|
59
|
+
from hpc_as_api.crypto import decrypt_message
|
|
60
|
+
|
|
61
|
+
# =========================================================================
# Configuration — all from environment variables
# =========================================================================
# Defined first so configuration problems below are reported through the
# module logger. The module-level ``logging.warning`` helper would instead
# write to the root logger and, if the root logger has no handlers yet,
# implicitly call ``logging.basicConfig()`` as an import-time side effect.
logger = logging.getLogger(__name__)

PROXY_HOST = os.getenv("HPC_PROXY_HOST", "0.0.0.0")
PROXY_PORT = int(os.getenv("HPC_PROXY_PORT", "8001"))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

# Only the string "true" (case-insensitive) enables Globus Compute.
USE_GLOBUS_COMPUTE = os.getenv("USE_GLOBUS_COMPUTE", "true").lower() == "true"
GLOBUS_COMPUTE_ENDPOINT_ID = os.getenv("GLOBUS_COMPUTE_ENDPOINT_ID")

# Fallback vLLM URL (used when USE_GLOBUS_COMPUTE=false — direct SSH tunnel mode)
LAKESHORE_VLLM_ENDPOINT = os.getenv("LAKESHORE_VLLM_ENDPOINT", "http://localhost:8000")

# WebSocket relay for token streaming. An empty string means "not configured".
RELAY_URL = os.getenv("RELAY_URL", "")
RELAY_SECRET = os.getenv("RELAY_SECRET", "")
RELAY_ENCRYPTION_KEY = os.getenv("RELAY_ENCRYPTION_KEY", "")


def _load_model_registry(raw: str) -> dict:
    """Parse the HPC_MODELS JSON string into a model-name -> config dict.

    Args:
        raw: The raw environment-variable value.

    Returns:
        The decoded object. Returns an empty dict (after logging a warning)
        when ``raw`` is not valid JSON or does not decode to a JSON object.
        The registry is consumed with ``.keys()`` / ``.items()``, so a list
        or scalar must not leak through.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("HPC_MODELS env var is not valid JSON — no models registered")
        return {}
    if not isinstance(parsed, dict):
        logger.warning(
            "HPC_MODELS env var must be a JSON object mapping model names to their "
            f"config (got {type(parsed).__name__}) — no models registered"
        )
        return {}
    return parsed


# Model registry — JSON string mapping model names to their HPC-side config.
# Example (set as environment variable):
#   HPC_MODELS='{"qwen25-vl-72b": {"hf_name": "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
#                                  "url": "http://ghi2-002:8000",
#                                  "context_reserve_output": 4096}}'
_raw_models = os.getenv("HPC_MODELS", "{}")
_HPC_MODELS: dict = _load_model_registry(_raw_models)
|
|
92
|
+
|
|
93
|
+
# =========================================================================
# Globus Compute client — initialized once at startup
# =========================================================================
# Assigned by ``lifespan`` on startup. Stays None when Globus Compute is
# disabled, unconfigured, or failed to initialize.
globus_client: GlobusComputeClient | None = None


def _init_globus_client() -> GlobusComputeClient | None:
    """Create and return a GlobusComputeClient, or None if it cannot be used.

    Returns None (never raises) when USE_GLOBUS_COMPUTE is off, when
    GLOBUS_COMPUTE_ENDPOINT_ID is unset, or when the client constructor
    raises. Callers treat None as "Globus Compute unavailable".
    """
    if not USE_GLOBUS_COMPUTE:
        return None
    if not GLOBUS_COMPUTE_ENDPOINT_ID:
        logger.warning(
            "USE_GLOBUS_COMPUTE=true but GLOBUS_COMPUTE_ENDPOINT_ID is not set — "
            "Globus Compute will be unavailable."
        )
        return None
    try:
        client = GlobusComputeClient(
            endpoint_id=GLOBUS_COMPUTE_ENDPOINT_ID,
            models=_HPC_MODELS,
            relay_secret=RELAY_SECRET,
        )
    except Exception as e:
        # Best-effort: the gateway still starts without Globus, but keep the
        # traceback. The message alone is rarely enough to diagnose SDK or
        # credential problems.
        logger.exception(f"Failed to initialize Globus Compute client: {e}")
        return None
    logger.info("Globus Compute client initialized")
    return client
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# =========================================================================
# FastAPI lifespan — replaces deprecated @app.on_event("startup")
# =========================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize resources on startup, clean up on shutdown.

    Creates the module-level ``globus_client`` (the only place it is assigned)
    and logs a startup banner. The standalone ``app`` in this module is built
    with ``lifespan=lifespan``. An application that only mounts ``router``
    must run this lifespan as well. Otherwise ``globus_client`` stays None
    and Globus-backed requests are rejected with 503.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global globus_client
    globus_client = _init_globus_client()

    logger.info("=" * 60)
    logger.info("HPC Gateway Starting")
    logger.info("=" * 60)
    logger.info(f"Mode: {'Globus Compute' if USE_GLOBUS_COMPUTE else 'SSH / Direct'}")
    if globus_client:
        logger.info(f"Globus Endpoint: {GLOBUS_COMPUTE_ENDPOINT_ID}")
        logger.info(f"Models: {list(_HPC_MODELS.keys())}")
    if RELAY_URL:
        logger.info(f"Relay: {RELAY_URL}")
    logger.info(f"Listening on: {PROXY_HOST}:{PROXY_PORT}")
    logger.info("=" * 60)

    try:
        yield  # Application runs here
    finally:
        # Shutdown: clean up the persistent Globus executor. Putting this in
        # ``finally`` makes it run even when an exception or cancellation is
        # delivered at the ``yield``, which would otherwise skip the cleanup.
        if globus_client:
            logger.info("Shutting down Globus Compute client...")
            try:
                globus_client.shutdown()
            finally:
                # Don't leave a shut-down client reachable by late requests.
                globus_client = None
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# =========================================================================
# APIRouter — the actual route definitions
# =========================================================================
# Using a router lets this module be embedded into any FastAPI app:
#   app.include_router(hpc_as_api.app.router, prefix="/hpc")
# The router is created without a lifespan of its own. The Globus Compute
# client the routes rely on is created in this module's ``lifespan`` handler,
# so an embedding app should run that handler as well.
router = APIRouter()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@router.get("/health")
|
|
160
|
+
async def health_check():
|
|
161
|
+
"""Return service health — no authentication required."""
|
|
162
|
+
return {
|
|
163
|
+
"status": "healthy",
|
|
164
|
+
"service": "HPC Gateway",
|
|
165
|
+
"mode": "globus_compute" if USE_GLOBUS_COMPUTE else "direct",
|
|
166
|
+
"globus_configured": bool(globus_client and globus_client.is_available()),
|
|
167
|
+
"models": list(_HPC_MODELS.keys()),
|
|
168
|
+
"relay_configured": bool(RELAY_URL),
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@router.get("/v1/models")
|
|
173
|
+
async def list_models(caller: CallerIdentity = Depends(authenticate)):
|
|
174
|
+
"""
|
|
175
|
+
List available HPC models in OpenAI-compatible format.
|
|
176
|
+
|
|
177
|
+
Returns the same schema as GET /v1/models on OpenAI's API, so any
|
|
178
|
+
OpenAI-compatible client can discover what models are available.
|
|
179
|
+
"""
|
|
180
|
+
from time import time as now
|
|
181
|
+
|
|
182
|
+
models = []
|
|
183
|
+
for name, info in _HPC_MODELS.items():
|
|
184
|
+
models.append(
|
|
185
|
+
{
|
|
186
|
+
"id": info.get("hf_name", name),
|
|
187
|
+
"object": "model",
|
|
188
|
+
"created": int(now()),
|
|
189
|
+
"owned_by": "hpc-as-api",
|
|
190
|
+
# Include the gateway-internal name as metadata.
|
|
191
|
+
# Callers can use either name — hf_name or the registry key.
|
|
192
|
+
"gateway_name": name,
|
|
193
|
+
}
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
return {"object": "list", "data": models}
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@router.post("/reload-auth")
|
|
200
|
+
async def reload_authentication():
|
|
201
|
+
"""
|
|
202
|
+
Force reload of Globus credentials from ~/.globus_compute/storage.db.
|
|
203
|
+
|
|
204
|
+
Call this after authenticating on the host machine. The Globus SDK caches
|
|
205
|
+
credentials in a module-level singleton; this clears the cache so fresh
|
|
206
|
+
tokens are re-read from disk.
|
|
207
|
+
"""
|
|
208
|
+
if not USE_GLOBUS_COMPUTE or not globus_client:
|
|
209
|
+
return {"success": False, "message": "Globus Compute not configured"}
|
|
210
|
+
|
|
211
|
+
try:
|
|
212
|
+
success, message = globus_client.reload_credentials()
|
|
213
|
+
return {"success": success, "message": message}
|
|
214
|
+
except Exception as e:
|
|
215
|
+
logger.error(f"Failed to reload credentials: {e}")
|
|
216
|
+
return {"success": False, "message": f"Failed to reload: {str(e)}"}
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@router.post("/v1/chat/completions")
|
|
220
|
+
async def proxy_chat_completions(
|
|
221
|
+
request: Request,
|
|
222
|
+
caller: CallerIdentity = Depends(authenticate),
|
|
223
|
+
):
|
|
224
|
+
"""
|
|
225
|
+
Forward a chat completion request to the HPC cluster.
|
|
226
|
+
|
|
227
|
+
Accepts the same request body as OpenAI's POST /v1/chat/completions.
|
|
228
|
+
Responses are also OpenAI-compatible — either a streaming SSE response
|
|
229
|
+
or a complete JSON response depending on the `stream` field.
|
|
230
|
+
"""
|
|
231
|
+
try:
|
|
232
|
+
body = await request.json()
|
|
233
|
+
except Exception as e:
|
|
234
|
+
raise HTTPException(status_code=400, detail="Invalid JSON body") from e
|
|
235
|
+
|
|
236
|
+
raw_model = body.get("model", "")
|
|
237
|
+
# LiteLLM adds an "openai/" prefix when forwarding. Strip it so the model
|
|
238
|
+
# name matches keys in the HPC_MODELS registry.
|
|
239
|
+
model = raw_model.removeprefix("openai/")
|
|
240
|
+
|
|
241
|
+
# Validate and sanitize messages before sending to the HPC cluster.
|
|
242
|
+
messages = validate_messages(body.get("messages", []))
|
|
243
|
+
temperature = body.get("temperature", 0.7)
|
|
244
|
+
stream = body.get("stream", False)
|
|
245
|
+
max_tokens = body.get("max_tokens") # None means "use model default"
|
|
246
|
+
|
|
247
|
+
logger.info(
|
|
248
|
+
f"Chat request: caller={caller.log_safe_id()}, model={model}, "
|
|
249
|
+
f"messages={len(messages)}, stream={stream}"
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
if USE_GLOBUS_COMPUTE:
|
|
253
|
+
return await _route_via_globus_compute(model, messages, temperature, max_tokens, stream)
|
|
254
|
+
else:
|
|
255
|
+
return await _route_via_direct(model, messages, temperature, max_tokens, stream)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# =========================================================================
# Internal routing helpers
# =========================================================================

async def _route_via_globus_compute(model, messages, temperature, max_tokens, stream):
    """Route a request to the HPC cluster via Globus Compute.

    When ``stream`` is set and RELAY_URL is configured, this first tries true
    token streaming through the relay. If that attempt raises before a
    response is returned, the request is re-submitted in batch mode. The one
    exception is a 401: batch mode would fail the same way, so it is re-raised.

    Batch mode waits for the complete result. It returns the result as JSON,
    or as a simulated SSE stream when ``stream`` is set.

    Raises:
        HTTPException: 503 when Globus Compute is unavailable or the job
            reports an error, 401 on Globus authentication errors, and 500 on
            any other unexpected failure.
    """
    if not globus_client or not globus_client.is_available():
        raise HTTPException(status_code=503, detail="Globus Compute not configured")

    if stream and RELAY_URL:
        try:
            return await _route_via_globus_compute_streaming(
                model, messages, temperature, max_tokens
            )
        except Exception as e:
            # Authentication failures are not relay problems; falling back
            # would only submit a second job that fails identically.
            if isinstance(e, HTTPException) and e.status_code == 401:
                raise
            # Only failures raised before the StreamingResponse is returned
            # (job submission/setup) land here. Relay connection errors happen
            # later, inside the response body generator, and are not retried.
            logger.warning(f"Relay streaming failed — falling back to batch mode: {e}")
            # Fall through to batch mode

    # Batch mode: submit job, wait for complete result, return as JSON (or simulated stream).
    try:
        logger.info(f"Submitting batch job to Globus endpoint: {GLOBUS_COMPUTE_ENDPOINT_ID}")
        result = await globus_client.submit_inference(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            model=model,
        )

        if "error" in result:
            error_msg = result.get("error", "Unknown error")
            error_type = result.get("error_type", "UnknownError")
            if error_type == "AuthenticationError":
                raise HTTPException(
                    status_code=401,
                    detail=f"Globus Compute authentication required: {error_msg}",
                )
            raise HTTPException(status_code=503, detail=f"HPC inference failed: {error_msg}")

        logger.info("Batch inference completed successfully")

        if stream:
            # Caller wants streaming but relay is unavailable.
            # Simulate streaming by splitting the complete response into word chunks.
            return _convert_json_to_sse_stream(result)
        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Globus Compute routing error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal gateway error: {str(e)}") from e
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
async def _route_via_globus_compute_streaming(model, messages, temperature, max_tokens):
    """
    True streaming from the HPC cluster via the WebSocket relay.

    1. Submit a streaming job to Globus Compute → returns channel_id immediately
    2. Connect to relay as consumer on that channel
    3. Receive tokens in real-time, convert to SSE, yield to caller

    The remote function on the HPC side connects to the relay as producer and
    streams tokens through it. The encryption key (RELAY_ENCRYPTION_KEY) is not
    passed with the job. Per the original design it is read from the
    endpoint's environment, so it never travels over the Globus AMQP channel.

    Every SSE event is a ``chat.completion.chunk`` carrying id/object/created/
    model. The stream always ends with a finish chunk followed by
    ``data: [DONE]``, including when the relay fails or closes without sending
    a "done" message.

    Raises:
        HTTPException: 401 on Globus authentication errors, 503 when the
            streaming job could not be submitted.
    """
    import time
    import uuid

    from websockets.asyncio.client import connect as ws_connect

    result = await globus_client.submit_streaming_inference(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        model=model,
        relay_url=RELAY_URL,
    )

    if "error" in result:
        error_msg = result.get("error", "Unknown error")
        error_type = result.get("error_type", "UnknownError")
        if error_type == "AuthenticationError":
            raise HTTPException(
                status_code=401,
                detail=f"Globus Compute authentication required: {error_msg}",
            )
        raise HTTPException(status_code=503, detail=f"HPC streaming failed: {error_msg}")

    channel_id = result["channel_id"]
    logger.info(f"Relay streaming: channel={channel_id[:8]}, relay={RELAY_URL}")

    # Envelope fields shared by every event. This matches the chunk shape that
    # _convert_json_to_sse_stream emits for the batch fallback.
    chunk_base = {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
    }

    def sse_event(delta, finish_reason=None, usage=None):
        """Serialize one chat.completion.chunk as an SSE ``data:`` event."""
        chunk = {
            **chunk_base,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        if usage:
            chunk["usage"] = usage
        return f"data: {json.dumps(chunk)}\n\n"

    async def sse_generator():
        """Connect to relay and yield SSE events as tokens arrive."""
        finished = False
        try:
            relay_consume_url = f"{RELAY_URL}/consume/{channel_id}"
            async with ws_connect(relay_consume_url) as ws:
                # Post-handshake auth: send secret as the first message, not in the URL.
                # This way it never appears in HTTP access logs or proxied headers.
                if RELAY_SECRET:
                    await ws.send(json.dumps({"type": "auth", "secret": RELAY_SECRET}))

                async for msg_str in ws:
                    # E2E decryption: if RELAY_ENCRYPTION_KEY is set, the producer
                    # encrypted each payload; decrypt_message() returns the
                    # original plaintext JSON. If no key is configured, passthrough.
                    if RELAY_ENCRYPTION_KEY:
                        msg_str = decrypt_message(RELAY_ENCRYPTION_KEY, msg_str)

                    msg = json.loads(msg_str)
                    msg_type = msg.get("type")

                    if msg_type == "token":
                        yield sse_event({"content": msg["content"]})
                    elif msg_type == "done":
                        yield sse_event({}, finish_reason="stop", usage=msg.get("usage"))
                        finished = True
                        break
                    elif msg_type == "error":
                        logger.error(
                            f"Relay error on channel {channel_id[:8]}: {msg.get('message')}"
                        )
                    # Messages of any other type are ignored.

        except Exception as e:
            logger.error(f"Relay connection failed: {e}", exc_info=True)

        if not finished:
            # The relay failed or closed without a "done" message. Still emit a
            # terminal chunk so clients see a finish_reason and [DONE] rather
            # than a stream that just stops.
            yield sse_event({}, finish_reason="stop")
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse_generator(), media_type="text/event-stream")
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
async def _route_via_direct(model, messages, temperature, max_tokens, stream):
    """
    Route directly to a vLLM server — no Globus Compute.

    Used when USE_GLOBUS_COMPUTE=false (e.g., SSH tunnel is already open and
    you want to skip the Globus auth layer). LAKESHORE_VLLM_ENDPOINT should
    point to the tunnel's local end.

    Raises:
        HTTPException: vLLM's own status on a non-200 reply, 503 when vLLM is
            unreachable, 504 on timeout, and 500 on any other failure. Errors
            that occur after a streaming response has started cannot be mapped
            to a status code.
    """
    if max_tokens is None:
        max_tokens = 2048  # safe default when no model registry is available

    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    target_url = f"{LAKESHORE_VLLM_ENDPOINT}/v1/chat/completions"
    logger.info(f"Direct vLLM request: {target_url}")

    try:
        if stream:
            return await _open_vllm_stream(target_url, payload)

        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(target_url, json=payload)
            if response.status_code != 200:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"vLLM error: {response.text}",
                )
            return response.json()

    except httpx.ConnectError as e:
        raise HTTPException(
            status_code=503,
            detail=f"Cannot connect to vLLM. Is the tunnel running? Error: {str(e)}",
        ) from e
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="vLLM request timed out") from e
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal gateway error: {str(e)}") from e


async def _open_vllm_stream(target_url: str, payload: dict) -> StreamingResponse:
    """Start a streaming POST to vLLM and relay it as an SSE StreamingResponse.

    Starlette iterates the response body only after the route handler has
    returned. The httpx client and upstream response must therefore outlive
    this function. They are closed in the body generator's ``finally``
    rather than by ``async with``, which would close them before the first
    line is read.

    Raises:
        HTTPException: with vLLM's status code on a non-200 reply.
        httpx.HTTPError: connection/timeout errors, for the caller to map.
    """
    client = httpx.AsyncClient(timeout=120.0)
    try:
        request = client.build_request("POST", target_url, json=payload)
        response = await client.send(request, stream=True)
    except BaseException:
        await client.aclose()
        raise

    if response.status_code != 200:
        try:
            error_text = await response.aread()
        finally:
            await response.aclose()
            await client.aclose()
        raise HTTPException(
            status_code=response.status_code,
            detail=f"vLLM error: {error_text.decode(errors='replace')}",
        )

    async def stream_generator():
        try:
            async for line in response.aiter_lines():
                if line.strip():
                    # aiter_lines() drops the blank separator lines, so re-add
                    # one. In SSE an event is dispatched only at a blank line;
                    # without it, consecutive "data:" lines merge into a
                    # single event.
                    yield line + "\n\n"
        finally:
            await response.aclose()
            await client.aclose()

    return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def _convert_json_to_sse_stream(
    json_response: dict,
    *,
    words_per_chunk: int = 2,
    delay_between_chunks: float = 0.05,
):
    """
    Simulate streaming by splitting a complete response into word-sized SSE chunks.

    Used when the caller requested ``stream=true`` but only a complete
    (batch) result is available, e.g. because the relay is unavailable. An
    SSE response is still returned to keep the API contract. It yields a few
    words at a time with small delays to create a natural typing effect.

    Content is split on single spaces and re-joined with single spaces, so
    the concatenated deltas reproduce the original text exactly, including
    newlines and runs of spaces.

    Args:
        json_response: An OpenAI-style chat.completion dict.
        words_per_chunk: Space-separated words per content chunk (at least 1).
            The default of 2 gives a smooth appearance without excessive events.
        delay_between_chunks: Seconds to sleep after each content chunk
            (default 50 ms).

    Returns:
        A StreamingResponse emitting ``chat.completion.chunk`` events followed
        by ``data: [DONE]``.
    """
    # range() below needs a positive step.
    words_per_chunk = max(1, words_per_chunk)

    async def sse_generator():
        choices = json_response.get("choices") or []
        if not choices:
            yield "data: [DONE]\n\n"
            return

        choice = choices[0]
        # "message" / "content" may be present but null (e.g. a tool-call
        # response with content=None); treat those as empty.
        message = choice.get("message") or {}
        content = message.get("content") or ""
        role = message.get("role", "assistant")

        chunk_base = {
            "id": json_response.get("id", ""),
            "object": "chat.completion.chunk",
            "created": json_response.get("created", 0),
            "model": json_response.get("model", ""),
        }

        # First chunk carries the role (OpenAI protocol requires this)
        if role:
            chunk = {
                **chunk_base,
                "choices": [
                    {"index": 0, "delta": {"role": role, "content": ""}, "finish_reason": None}
                ],
            }
            yield f"data: {json.dumps(chunk)}\n\n"

        # Content chunks: small groups of space-separated words
        if content:
            words = content.split(" ")
            for i in range(0, len(words), words_per_chunk):
                text_chunk = " ".join(words[i : i + words_per_chunk])
                if i:
                    # Re-insert the space that separated this group from the previous one
                    text_chunk = " " + text_chunk
                chunk = {
                    **chunk_base,
                    "choices": [
                        {"index": 0, "delta": {"content": text_chunk}, "finish_reason": None}
                    ],
                }
                yield f"data: {json.dumps(chunk)}\n\n"
                await asyncio.sleep(delay_between_chunks)

        # Final chunk signals completion and includes usage stats. A null
        # finish_reason/usage in the batch result falls back to the defaults.
        chunk = {
            **chunk_base,
            "choices": [
                {"index": 0, "delta": {}, "finish_reason": choice.get("finish_reason") or "stop"}
            ],
            "usage": json_response.get("usage") or {},
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse_generator(), media_type="text/event-stream")
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
# =========================================================================
# Standalone FastAPI app (used when running as a service)
# =========================================================================
# Both ``uvicorn hpc_as_api.app:app`` and ``main()`` serve this instance.
# ``lifespan`` creates and tears down the Globus Compute client. The shared
# ``router`` is mounted at the root path (empty prefix).
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="HPC Gateway",
    description="OpenAI-compatible API gateway for HPC clusters via Globus Compute",
)
app.include_router(router, prefix="")
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def main():
    """Run the standalone gateway with uvicorn.

    Binds to HPC_PROXY_HOST:HPC_PROXY_PORT and applies LOG_LEVEL both to
    uvicorn and to the root logger, which this package's loggers propagate to.
    """
    import uvicorn

    # uvicorn configures only its own "uvicorn.*" loggers. Without a root
    # handler, this module's INFO messages (startup banner, request logs) fall
    # to logging's last-resort handler, which drops anything below WARNING.
    # basicConfig is a no-op if the root logger is already configured.
    level = logging.getLevelName(LOG_LEVEL.upper())
    if not isinstance(level, int):
        level = logging.INFO  # unknown level name: getLevelName returns a str
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    uvicorn.run(app, host=PROXY_HOST, port=PROXY_PORT, log_level=LOG_LEVEL.lower())


if __name__ == "__main__":
    main()
|