chutes 0.3.61rc22__py3-none-any.whl → 0.5.3rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chutes/_version.py +1 -1
- chutes/cfsv_v2 +0 -0
- chutes/cfsv_v3 +0 -0
- chutes/cfsv_wrapper.py +135 -11
- chutes/chute/base.py +4 -0
- chutes/chute/cord.py +171 -11
- chutes/chute/template/diffusion.py +2 -0
- chutes/chute/template/embedding.py +76 -94
- chutes/chute/template/helpers.py +78 -11
- chutes/chute/template/sglang.py +50 -12
- chutes/chute/template/vllm.py +92 -155
- chutes/chutes-cfsv.so +0 -0
- chutes/chutes-inspecto.so +0 -0
- chutes/chutes-logintercept.so +0 -0
- chutes/chutes-netnanny.so +0 -0
- chutes/chutes-runint.so +0 -0
- chutes/cli.py +4 -0
- chutes/entrypoint/_shared.py +58 -3
- chutes/entrypoint/build.py +3 -1
- chutes/entrypoint/deploy.py +1 -0
- chutes/entrypoint/login.py +158 -0
- chutes/entrypoint/run.py +558 -259
- chutes/entrypoint/verify.py +190 -0
- chutes/envdump/envdump.so +0 -0
- chutes/image/__init__.py +4 -2
- chutes/image/standard/sglang.py +1 -1
- chutes/image/standard/vllm.py +1 -1
- chutes/inspecto.py +3 -3
- chutes/util/hf.py +211 -0
- {chutes-0.3.61rc22.dist-info → chutes-0.5.3rc1.dist-info}/METADATA +9 -6
- {chutes-0.3.61rc22.dist-info → chutes-0.5.3rc1.dist-info}/RECORD +35 -30
- {chutes-0.3.61rc22.dist-info → chutes-0.5.3rc1.dist-info}/entry_points.txt +0 -1
- chutes/chute/template/tei.py +0 -177
- chutes/image/standard/tei.py +0 -8
- {chutes-0.3.61rc22.dist-info → chutes-0.5.3rc1.dist-info}/LICENSE +0 -0
- {chutes-0.3.61rc22.dist-info → chutes-0.5.3rc1.dist-info}/WHEEL +0 -0
- {chutes-0.3.61rc22.dist-info → chutes-0.5.3rc1.dist-info}/top_level.txt +0 -0
chutes/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
version = "0.3.61.rc22"
|
|
1
|
+
version = "0.5.3.rc1"
|
chutes/cfsv_v2
ADDED
|
Binary file
|
chutes/cfsv_v3
ADDED
|
Binary file
|
chutes/cfsv_wrapper.py
CHANGED
|
@@ -1,16 +1,140 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
from
|
|
2
|
+
import ctypes
|
|
3
|
+
import asyncio
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from fastapi import Request
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
class CFSVWrapper:
|
|
9
|
+
def __init__(
|
|
10
|
+
self, lib_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "chutes-cfsv.so")
|
|
11
|
+
):
|
|
12
|
+
self.lib = ctypes.CDLL(lib_path)
|
|
13
13
|
|
|
14
|
+
# cfsv_challenge(base_path, salt, sparse, index_file, exclude_path, result_buf, result_buf_size)
|
|
15
|
+
self.lib.cfsv_challenge.argtypes = [
|
|
16
|
+
ctypes.c_char_p, # base_path
|
|
17
|
+
ctypes.c_char_p, # salt
|
|
18
|
+
ctypes.c_int, # sparse
|
|
19
|
+
ctypes.c_char_p, # index_file
|
|
20
|
+
ctypes.c_char_p, # exclude_path
|
|
21
|
+
ctypes.c_char_p, # result_buf
|
|
22
|
+
ctypes.c_size_t, # result_buf_size
|
|
23
|
+
]
|
|
24
|
+
self.lib.cfsv_challenge.restype = ctypes.c_int
|
|
14
25
|
|
|
15
|
-
|
|
16
|
-
|
|
26
|
+
# cfsv_sizetest(test_dir, size_gib)
|
|
27
|
+
self.lib.cfsv_sizetest.argtypes = [
|
|
28
|
+
ctypes.c_char_p, # test_dir
|
|
29
|
+
ctypes.c_size_t, # size_gib
|
|
30
|
+
]
|
|
31
|
+
self.lib.cfsv_sizetest.restype = ctypes.c_int
|
|
32
|
+
|
|
33
|
+
# cfsv_version()
|
|
34
|
+
self.lib.cfsv_version.argtypes = []
|
|
35
|
+
self.lib.cfsv_version.restype = ctypes.c_char_p
|
|
36
|
+
|
|
37
|
+
def challenge(
|
|
38
|
+
self,
|
|
39
|
+
salt,
|
|
40
|
+
mode="full",
|
|
41
|
+
base_path="/",
|
|
42
|
+
index_file="/etc/chutesfs.index",
|
|
43
|
+
exclude_path="/app/chute.py",
|
|
44
|
+
):
|
|
45
|
+
"""
|
|
46
|
+
Compute filesystem challenge hash.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
salt: Challenge salt from validator
|
|
50
|
+
mode: "sparse" or "full" (default: "full")
|
|
51
|
+
base_path: Root path for file scanning (default: "/")
|
|
52
|
+
index_file: Path to encrypted index file (default: "/etc/chutesfs.index")
|
|
53
|
+
exclude_path: Path to exclude from hashing (default: "/app/chute.py")
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
Hex string hash on success, None on failure
|
|
57
|
+
"""
|
|
58
|
+
sparse = 1 if mode == "sparse" else 0
|
|
59
|
+
result_buf = ctypes.create_string_buffer(65)
|
|
60
|
+
ret = self.lib.cfsv_challenge(
|
|
61
|
+
base_path.encode() if isinstance(base_path, str) else base_path,
|
|
62
|
+
salt.encode() if isinstance(salt, str) else salt,
|
|
63
|
+
sparse,
|
|
64
|
+
index_file.encode() if isinstance(index_file, str) else index_file,
|
|
65
|
+
exclude_path.encode() if isinstance(exclude_path, str) else exclude_path,
|
|
66
|
+
result_buf,
|
|
67
|
+
65,
|
|
68
|
+
)
|
|
69
|
+
if ret == 0:
|
|
70
|
+
return result_buf.value.decode("utf-8")
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
def sizetest(self, test_dir, size_gib):
|
|
74
|
+
"""
|
|
75
|
+
Test filesystem capacity and integrity.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
test_dir: Directory to test in
|
|
79
|
+
size_gib: Size in GiB to test
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
True on success, False on failure
|
|
83
|
+
"""
|
|
84
|
+
ret = self.lib.cfsv_sizetest(
|
|
85
|
+
test_dir.encode() if isinstance(test_dir, str) else test_dir,
|
|
86
|
+
size_gib,
|
|
87
|
+
)
|
|
88
|
+
return ret == 0
|
|
89
|
+
|
|
90
|
+
def version(self):
|
|
91
|
+
"""
|
|
92
|
+
Get library version.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Version string
|
|
96
|
+
"""
|
|
97
|
+
result = self.lib.cfsv_version()
|
|
98
|
+
if result:
|
|
99
|
+
return result.decode("utf-8")
|
|
100
|
+
return None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@lru_cache(maxsize=1)
|
|
104
|
+
def get_cfsv():
|
|
105
|
+
"""Lazily initialize CFSV wrapper (only works on Linux)."""
|
|
106
|
+
return CFSVWrapper()
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
async def handle_challenge(request: Request):
|
|
110
|
+
loop = asyncio.get_event_loop()
|
|
111
|
+
salt = request.state.decrypted["salt"]
|
|
112
|
+
mode = request.state.decrypted.get("mode", "full")
|
|
113
|
+
exclude_path = request.state.decrypted.get("exclude_path", "/app/chute.py")
|
|
114
|
+
result = await loop.run_in_executor(
|
|
115
|
+
None,
|
|
116
|
+
get_cfsv().challenge,
|
|
117
|
+
salt,
|
|
118
|
+
mode,
|
|
119
|
+
"/",
|
|
120
|
+
"/etc/chutesfs.index",
|
|
121
|
+
exclude_path,
|
|
122
|
+
)
|
|
123
|
+
return {"result": result}
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def handle_sizetest(request: Request):
|
|
127
|
+
loop = asyncio.get_event_loop()
|
|
128
|
+
test_dir = request.state.decrypted.get("test_dir", "/tmp")
|
|
129
|
+
size_gib = request.state.decrypted.get("size_gib", 10)
|
|
130
|
+
result = await loop.run_in_executor(
|
|
131
|
+
None,
|
|
132
|
+
get_cfsv().sizetest,
|
|
133
|
+
test_dir,
|
|
134
|
+
size_gib,
|
|
135
|
+
)
|
|
136
|
+
return {"result": result}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
async def handle_version(request: Request):
|
|
140
|
+
return {"result": get_cfsv().version()}
|
chutes/chute/base.py
CHANGED
|
@@ -45,6 +45,8 @@ class Chute(FastAPI):
|
|
|
45
45
|
scaling_threshold: float = 0.75,
|
|
46
46
|
allow_external_egress: bool = False,
|
|
47
47
|
encrypted_fs: bool = False,
|
|
48
|
+
passthrough_headers: dict = {},
|
|
49
|
+
tee: bool = False,
|
|
48
50
|
**kwargs,
|
|
49
51
|
):
|
|
50
52
|
from chutes.chute.cord import Cord
|
|
@@ -71,8 +73,10 @@ class Chute(FastAPI):
|
|
|
71
73
|
self.shutdown_after_seconds = shutdown_after_seconds
|
|
72
74
|
self.allow_external_egress = allow_external_egress
|
|
73
75
|
self.encrypted_fs = encrypted_fs
|
|
76
|
+
self.passthrough_headers = passthrough_headers
|
|
74
77
|
self.docs_url = None
|
|
75
78
|
self.redoc_url = None
|
|
79
|
+
self.tee = tee
|
|
76
80
|
|
|
77
81
|
@property
|
|
78
82
|
def name(self):
|
chutes/chute/cord.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import aiohttp
|
|
3
2
|
import re
|
|
4
|
-
import
|
|
3
|
+
import uuid
|
|
5
4
|
import gzip
|
|
6
5
|
import time
|
|
7
|
-
import orjson as json
|
|
8
|
-
import fickling
|
|
9
6
|
import pickle
|
|
10
7
|
import base64
|
|
8
|
+
import aiohttp
|
|
9
|
+
import asyncio
|
|
10
|
+
import backoff
|
|
11
|
+
import fickling
|
|
12
|
+
import orjson as json
|
|
11
13
|
from pydantic import ValidationError
|
|
12
14
|
from typing import Optional, Dict, Any
|
|
13
15
|
from fastapi import Request, HTTPException, status
|
|
@@ -44,6 +46,7 @@ class Cord:
|
|
|
44
46
|
minimal_input_schema: Optional[Any] = None,
|
|
45
47
|
output_content_type: Optional[str] = None,
|
|
46
48
|
output_schema: Optional[Dict] = None,
|
|
49
|
+
sglang_passthrough: bool = False,
|
|
47
50
|
**session_kwargs,
|
|
48
51
|
):
|
|
49
52
|
"""
|
|
@@ -67,6 +70,7 @@ class Cord:
|
|
|
67
70
|
self._session_kwargs = session_kwargs
|
|
68
71
|
self._provision_timeout = provision_timeout
|
|
69
72
|
self._config = None
|
|
73
|
+
self._sglang_passthrough = sglang_passthrough
|
|
70
74
|
self.input_models = (
|
|
71
75
|
[input_schema] if input_schema and hasattr(input_schema, "__fields__") else None
|
|
72
76
|
)
|
|
@@ -156,6 +160,9 @@ class Cord:
|
|
|
156
160
|
raise InvalidPath(path)
|
|
157
161
|
self._public_api_path = path
|
|
158
162
|
|
|
163
|
+
def _is_sglang_passthrough(self) -> bool:
|
|
164
|
+
return self._passthrough and self._sglang_passthrough
|
|
165
|
+
|
|
159
166
|
@asynccontextmanager
|
|
160
167
|
async def _local_call_base(self, *args, **kwargs):
|
|
161
168
|
"""
|
|
@@ -268,15 +275,28 @@ class Cord:
|
|
|
268
275
|
yield data["result"]
|
|
269
276
|
|
|
270
277
|
@asynccontextmanager
|
|
271
|
-
async def _passthrough_call(self, **kwargs):
|
|
278
|
+
async def _passthrough_call(self, request: Request, **kwargs):
|
|
272
279
|
"""
|
|
273
280
|
Call a passthrough endpoint.
|
|
274
281
|
"""
|
|
275
282
|
logger.debug(
|
|
276
283
|
f"Received passthrough call, passing along to {self.passthrough_path} via {self._method}"
|
|
277
284
|
)
|
|
285
|
+
headers = kwargs.pop("headers", {}) or {}
|
|
286
|
+
if self._app.passthrough_headers:
|
|
287
|
+
headers.update(self._app.passthrough_headers)
|
|
288
|
+
kwargs["headers"] = headers
|
|
289
|
+
|
|
290
|
+
# Set (if needed) timeout.
|
|
291
|
+
timeout = None
|
|
292
|
+
if self._is_sglang_passthrough():
|
|
293
|
+
timeout = aiohttp.ClientTimeout(connect=5.0, total=None)
|
|
294
|
+
else:
|
|
295
|
+
total_timeout = kwargs.pop("timeout", 1800)
|
|
296
|
+
timeout = aiohttp.ClientTimeout(connect=5.0, total=total_timeout)
|
|
297
|
+
|
|
278
298
|
async with aiohttp.ClientSession(
|
|
279
|
-
timeout=
|
|
299
|
+
timeout=timeout,
|
|
280
300
|
read_bufsize=8 * 1024 * 1024,
|
|
281
301
|
base_url=f"http://127.0.0.1:{self._passthrough_port or 8000}",
|
|
282
302
|
) as session:
|
|
@@ -300,9 +320,88 @@ class Cord:
|
|
|
300
320
|
function=self._func.__name__,
|
|
301
321
|
).set_to_current_time()
|
|
302
322
|
encrypt = getattr(request.state, "_encrypt", None)
|
|
323
|
+
|
|
303
324
|
try:
|
|
304
325
|
if self._passthrough:
|
|
305
|
-
|
|
326
|
+
rid = getattr(request.state, "sglang_rid", None)
|
|
327
|
+
|
|
328
|
+
# SGLang passthrough: run upstream call and disconnect watcher in parallel
|
|
329
|
+
if self._is_sglang_passthrough():
|
|
330
|
+
|
|
331
|
+
async def call_upstream():
|
|
332
|
+
async with self._passthrough_call(request, **kwargs) as response:
|
|
333
|
+
if not 200 <= response.status < 300:
|
|
334
|
+
try:
|
|
335
|
+
error_detail = await response.json()
|
|
336
|
+
except Exception:
|
|
337
|
+
error_detail = await response.text()
|
|
338
|
+
logger.error(
|
|
339
|
+
f"Failed to generate response from func={self._func.__name__}: {response.status=} -> {error_detail}"
|
|
340
|
+
)
|
|
341
|
+
raise HTTPException(
|
|
342
|
+
status_code=response.status,
|
|
343
|
+
detail=error_detail,
|
|
344
|
+
)
|
|
345
|
+
if encrypt:
|
|
346
|
+
raw = await response.read()
|
|
347
|
+
return {"json": encrypt(raw)}
|
|
348
|
+
return await response.json()
|
|
349
|
+
|
|
350
|
+
async def watch_disconnect():
|
|
351
|
+
try:
|
|
352
|
+
while True:
|
|
353
|
+
message = await request._receive()
|
|
354
|
+
if message.get("type") == "http.disconnect":
|
|
355
|
+
logger.info(
|
|
356
|
+
f"[{self._func.__name__}] Received http.disconnect, "
|
|
357
|
+
f"aborting upstream SGLang request (rid={rid})"
|
|
358
|
+
)
|
|
359
|
+
try:
|
|
360
|
+
await self._abort_sglang_request(rid)
|
|
361
|
+
except Exception as exc:
|
|
362
|
+
logger.warning(
|
|
363
|
+
f"Error while sending abort_request for rid={rid}: {exc}"
|
|
364
|
+
)
|
|
365
|
+
raise HTTPException(
|
|
366
|
+
status_code=499,
|
|
367
|
+
detail="Client disconnected during SGLang request",
|
|
368
|
+
)
|
|
369
|
+
except HTTPException:
|
|
370
|
+
raise
|
|
371
|
+
except Exception as exc:
|
|
372
|
+
logger.warning(f"watch_disconnect error: {exc}")
|
|
373
|
+
raise HTTPException(
|
|
374
|
+
status_code=499,
|
|
375
|
+
detail="Client disconnected during SGLang request",
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
upstream_task = asyncio.create_task(call_upstream())
|
|
379
|
+
watcher_task = asyncio.create_task(watch_disconnect())
|
|
380
|
+
|
|
381
|
+
done, pending = await asyncio.wait(
|
|
382
|
+
{upstream_task, watcher_task},
|
|
383
|
+
return_when=asyncio.FIRST_COMPLETED,
|
|
384
|
+
)
|
|
385
|
+
for task in pending:
|
|
386
|
+
task.cancel()
|
|
387
|
+
|
|
388
|
+
if watcher_task in done:
|
|
389
|
+
exc = watcher_task.exception()
|
|
390
|
+
|
|
391
|
+
if exc:
|
|
392
|
+
raise exc
|
|
393
|
+
raise HTTPException(
|
|
394
|
+
status_code=499,
|
|
395
|
+
detail="Client disconnected during SGLang request",
|
|
396
|
+
)
|
|
397
|
+
result = upstream_task.result()
|
|
398
|
+
logger.success(
|
|
399
|
+
f"Completed request [{self._func.__name__} passthrough={self._passthrough}] "
|
|
400
|
+
f"in {time.time() - started_at} seconds"
|
|
401
|
+
)
|
|
402
|
+
return result
|
|
403
|
+
|
|
404
|
+
async with self._passthrough_call(request, **kwargs) as response:
|
|
306
405
|
if not 200 <= response.status < 300:
|
|
307
406
|
try:
|
|
308
407
|
error_detail = await response.json()
|
|
@@ -316,15 +415,18 @@ class Cord:
|
|
|
316
415
|
detail=error_detail,
|
|
317
416
|
)
|
|
318
417
|
logger.success(
|
|
319
|
-
f"Completed request [{self._func.__name__} passthrough={self._passthrough}] in {time.time() - started_at} seconds"
|
|
418
|
+
f"Completed request [{self._func.__name__} passthrough={self._passthrough}] "
|
|
419
|
+
f"in {time.time() - started_at} seconds"
|
|
320
420
|
)
|
|
321
421
|
if encrypt:
|
|
322
422
|
return {"json": encrypt(await response.read())}
|
|
323
423
|
return await response.json()
|
|
324
424
|
|
|
325
|
-
|
|
425
|
+
# Non-passthrough call (local Python function)
|
|
426
|
+
response = await asyncio.wait_for(self._func(self._app, *args, **kwargs), 1800)
|
|
326
427
|
logger.success(
|
|
327
|
-
f"Completed request [{self._func.__name__} passthrough={self._passthrough}] in {time.time() - started_at} seconds"
|
|
428
|
+
f"Completed request [{self._func.__name__} passthrough={self._passthrough}] "
|
|
429
|
+
f"in {time.time() - started_at} seconds"
|
|
328
430
|
)
|
|
329
431
|
if hasattr(response, "body"):
|
|
330
432
|
if encrypt:
|
|
@@ -340,6 +442,19 @@ class Cord:
|
|
|
340
442
|
if encrypt:
|
|
341
443
|
return {"json": encrypt(json.dumps(response))}
|
|
342
444
|
return response
|
|
445
|
+
except asyncio.CancelledError:
|
|
446
|
+
rid = getattr(request.state, "sglang_rid", None)
|
|
447
|
+
if self._is_sglang_passthrough():
|
|
448
|
+
logger.info(
|
|
449
|
+
f"Non-stream request for {self._func.__name__} cancelled "
|
|
450
|
+
f"(likely client disconnect), aborting SGLang rid={rid}"
|
|
451
|
+
)
|
|
452
|
+
try:
|
|
453
|
+
await self._abort_sglang_request(rid)
|
|
454
|
+
except Exception as exc:
|
|
455
|
+
logger.warning(f"Error while sending abort_request for rid={rid}: {exc}")
|
|
456
|
+
status = 499
|
|
457
|
+
raise
|
|
343
458
|
except Exception as exc:
|
|
344
459
|
logger.error(f"Error performing non-streamed call: {exc}")
|
|
345
460
|
status = 500
|
|
@@ -371,7 +486,9 @@ class Cord:
|
|
|
371
486
|
encrypt = getattr(request.state, "_encrypt", None)
|
|
372
487
|
try:
|
|
373
488
|
if self._passthrough:
|
|
374
|
-
|
|
489
|
+
rid = getattr(request.state, "sglang_rid", None)
|
|
490
|
+
|
|
491
|
+
async with self._passthrough_call(request, **kwargs) as response:
|
|
375
492
|
if not 200 <= response.status < 300:
|
|
376
493
|
try:
|
|
377
494
|
error_detail = await response.json()
|
|
@@ -385,6 +502,13 @@ class Cord:
|
|
|
385
502
|
detail=error_detail,
|
|
386
503
|
)
|
|
387
504
|
async for content in response.content:
|
|
505
|
+
if await request.is_disconnected():
|
|
506
|
+
logger.info(
|
|
507
|
+
f"Client disconnected for {self._func.__name__}, aborting upstream (rid={rid})"
|
|
508
|
+
)
|
|
509
|
+
await self._abort_sglang_request(rid)
|
|
510
|
+
break
|
|
511
|
+
|
|
388
512
|
if encrypt:
|
|
389
513
|
yield encrypt(content) + "\n"
|
|
390
514
|
else:
|
|
@@ -402,6 +526,20 @@ class Cord:
|
|
|
402
526
|
logger.success(
|
|
403
527
|
f"Completed request [{self._func.__name__}] in {time.time() - started_at} seconds"
|
|
404
528
|
)
|
|
529
|
+
except asyncio.CancelledError:
|
|
530
|
+
rid = getattr(request.state, "sglang_rid", None)
|
|
531
|
+
if self._is_sglang_passthrough():
|
|
532
|
+
logger.info(
|
|
533
|
+
f"Streaming cancelled for {self._func.__name__} "
|
|
534
|
+
f"(likely client disconnect), aborting SGLang rid={rid}"
|
|
535
|
+
)
|
|
536
|
+
try:
|
|
537
|
+
await self._abort_sglang_request(rid)
|
|
538
|
+
except Exception as exc:
|
|
539
|
+
logger.warning(f"Error while sending abort_request for rid={rid}: {exc}")
|
|
540
|
+
status = 499
|
|
541
|
+
raise
|
|
542
|
+
|
|
405
543
|
except Exception as exc:
|
|
406
544
|
logger.error(f"Error performing stream call: {exc}")
|
|
407
545
|
status = 500
|
|
@@ -418,6 +556,21 @@ class Cord:
|
|
|
418
556
|
status=status,
|
|
419
557
|
).observe(time.time() - started_at)
|
|
420
558
|
|
|
559
|
+
async def _abort_sglang_request(self, rid: Optional[str]):
|
|
560
|
+
if not rid or not self._is_sglang_passthrough():
|
|
561
|
+
return
|
|
562
|
+
try:
|
|
563
|
+
async with aiohttp.ClientSession(
|
|
564
|
+
timeout=aiohttp.ClientTimeout(connect=5.0, total=15.0),
|
|
565
|
+
base_url=f"http://127.0.0.1:{self._passthrough_port or 8000}",
|
|
566
|
+
headers=self._app.passthrough_headers or {},
|
|
567
|
+
) as session:
|
|
568
|
+
logger.warning(f"Aborting SGLang request {rid=}")
|
|
569
|
+
await session.post("/abort_request", json={"rid": rid})
|
|
570
|
+
logger.success(f"Sent SGLang abort_request for rid={rid}")
|
|
571
|
+
except Exception as exc:
|
|
572
|
+
logger.warning(f"Failed to send abort_request for rid={rid}: {exc}")
|
|
573
|
+
|
|
421
574
|
async def _request_handler(self, request: Request):
|
|
422
575
|
"""
|
|
423
576
|
Decode/deserialize incoming request and call the appropriate function.
|
|
@@ -448,6 +601,13 @@ class Cord:
|
|
|
448
601
|
else:
|
|
449
602
|
args = []
|
|
450
603
|
kwargs = {"json": request.state.decrypted} if request.state.decrypted else {}
|
|
604
|
+
|
|
605
|
+
# Set a custom request ID for SGLang passthroughs.
|
|
606
|
+
if self._is_sglang_passthrough() and isinstance(kwargs.get("json"), dict):
|
|
607
|
+
rid = uuid.uuid4().hex
|
|
608
|
+
kwargs["json"].setdefault("rid", rid)
|
|
609
|
+
request.state.sglang_rid = rid
|
|
610
|
+
|
|
451
611
|
if not self._passthrough:
|
|
452
612
|
if self.input_models and all([isinstance(args[idx], dict) for idx in range(len(args))]):
|
|
453
613
|
try:
|
|
@@ -152,6 +152,7 @@ def build_diffusion_chute(
|
|
|
152
152
|
max_instances: int = 1,
|
|
153
153
|
scaling_threshold: float = 0.75,
|
|
154
154
|
shutdown_after_seconds: int = 300,
|
|
155
|
+
tee: bool = False,
|
|
155
156
|
):
|
|
156
157
|
chute = Chute(
|
|
157
158
|
username=username,
|
|
@@ -165,6 +166,7 @@ def build_diffusion_chute(
|
|
|
165
166
|
shutdown_after_seconds=shutdown_after_seconds,
|
|
166
167
|
max_instances=max_instances,
|
|
167
168
|
scaling_threshold=scaling_threshold,
|
|
169
|
+
tee=tee,
|
|
168
170
|
)
|
|
169
171
|
|
|
170
172
|
@chute.on_startup()
|