chutes 0.3.61rc22__py3-none-any.whl → 0.5.3rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chutes/_version.py CHANGED
@@ -1 +1 @@
1
- version = "0.3.61.rc22"
1
+ version = "0.5.3.rc1"
chutes/cfsv_v2 ADDED
Binary file
chutes/cfsv_v3 ADDED
Binary file
chutes/cfsv_wrapper.py CHANGED
@@ -1,16 +1,140 @@
1
1
  import os
2
- import sys
3
- import stat
4
- import subprocess
5
- from pathlib import Path
2
+ import ctypes
3
+ import asyncio
4
+ from functools import lru_cache
5
+ from fastapi import Request
6
6
 
7
7
 
8
- def main():
9
- binary_path = Path(__file__).parent / "cfsv"
10
- os.chmod(binary_path, os.stat(binary_path).st_mode | stat.S_IEXEC)
11
- result = subprocess.run([str(binary_path)] + sys.argv[1:])
12
- sys.exit(result.returncode)
8
+ class CFSVWrapper:
9
+ def __init__(
10
+ self, lib_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "chutes-cfsv.so")
11
+ ):
12
+ self.lib = ctypes.CDLL(lib_path)
13
13
 
14
+ # cfsv_challenge(base_path, salt, sparse, index_file, exclude_path, result_buf, result_buf_size)
15
+ self.lib.cfsv_challenge.argtypes = [
16
+ ctypes.c_char_p, # base_path
17
+ ctypes.c_char_p, # salt
18
+ ctypes.c_int, # sparse
19
+ ctypes.c_char_p, # index_file
20
+ ctypes.c_char_p, # exclude_path
21
+ ctypes.c_char_p, # result_buf
22
+ ctypes.c_size_t, # result_buf_size
23
+ ]
24
+ self.lib.cfsv_challenge.restype = ctypes.c_int
14
25
 
15
- if __name__ == "__main__":
16
- main()
26
+ # cfsv_sizetest(test_dir, size_gib)
27
+ self.lib.cfsv_sizetest.argtypes = [
28
+ ctypes.c_char_p, # test_dir
29
+ ctypes.c_size_t, # size_gib
30
+ ]
31
+ self.lib.cfsv_sizetest.restype = ctypes.c_int
32
+
33
+ # cfsv_version()
34
+ self.lib.cfsv_version.argtypes = []
35
+ self.lib.cfsv_version.restype = ctypes.c_char_p
36
+
37
+ def challenge(
38
+ self,
39
+ salt,
40
+ mode="full",
41
+ base_path="/",
42
+ index_file="/etc/chutesfs.index",
43
+ exclude_path="/app/chute.py",
44
+ ):
45
+ """
46
+ Compute filesystem challenge hash.
47
+
48
+ Args:
49
+ salt: Challenge salt from validator
50
+ mode: "sparse" or "full" (default: "full")
51
+ base_path: Root path for file scanning (default: "/")
52
+ index_file: Path to encrypted index file (default: "/etc/chutesfs.index")
53
+ exclude_path: Path to exclude from hashing (default: "/app/chute.py")
54
+
55
+ Returns:
56
+ Hex string hash on success, None on failure
57
+ """
58
+ sparse = 1 if mode == "sparse" else 0
59
+ result_buf = ctypes.create_string_buffer(65)
60
+ ret = self.lib.cfsv_challenge(
61
+ base_path.encode() if isinstance(base_path, str) else base_path,
62
+ salt.encode() if isinstance(salt, str) else salt,
63
+ sparse,
64
+ index_file.encode() if isinstance(index_file, str) else index_file,
65
+ exclude_path.encode() if isinstance(exclude_path, str) else exclude_path,
66
+ result_buf,
67
+ 65,
68
+ )
69
+ if ret == 0:
70
+ return result_buf.value.decode("utf-8")
71
+ return None
72
+
73
+ def sizetest(self, test_dir, size_gib):
74
+ """
75
+ Test filesystem capacity and integrity.
76
+
77
+ Args:
78
+ test_dir: Directory to test in
79
+ size_gib: Size in GiB to test
80
+
81
+ Returns:
82
+ True on success, False on failure
83
+ """
84
+ ret = self.lib.cfsv_sizetest(
85
+ test_dir.encode() if isinstance(test_dir, str) else test_dir,
86
+ size_gib,
87
+ )
88
+ return ret == 0
89
+
90
+ def version(self):
91
+ """
92
+ Get library version.
93
+
94
+ Returns:
95
+ Version string
96
+ """
97
+ result = self.lib.cfsv_version()
98
+ if result:
99
+ return result.decode("utf-8")
100
+ return None
101
+
102
+
103
+ @lru_cache(maxsize=1)
104
+ def get_cfsv():
105
+ """Lazily initialize CFSV wrapper (only works on Linux)."""
106
+ return CFSVWrapper()
107
+
108
+
109
+ async def handle_challenge(request: Request):
110
+ loop = asyncio.get_event_loop()
111
+ salt = request.state.decrypted["salt"]
112
+ mode = request.state.decrypted.get("mode", "full")
113
+ exclude_path = request.state.decrypted.get("exclude_path", "/app/chute.py")
114
+ result = await loop.run_in_executor(
115
+ None,
116
+ get_cfsv().challenge,
117
+ salt,
118
+ mode,
119
+ "/",
120
+ "/etc/chutesfs.index",
121
+ exclude_path,
122
+ )
123
+ return {"result": result}
124
+
125
+
126
+ async def handle_sizetest(request: Request):
127
+ loop = asyncio.get_event_loop()
128
+ test_dir = request.state.decrypted.get("test_dir", "/tmp")
129
+ size_gib = request.state.decrypted.get("size_gib", 10)
130
+ result = await loop.run_in_executor(
131
+ None,
132
+ get_cfsv().sizetest,
133
+ test_dir,
134
+ size_gib,
135
+ )
136
+ return {"result": result}
137
+
138
+
139
+ async def handle_version(request: Request):
140
+ return {"result": get_cfsv().version()}
chutes/chute/base.py CHANGED
@@ -45,6 +45,8 @@ class Chute(FastAPI):
45
45
  scaling_threshold: float = 0.75,
46
46
  allow_external_egress: bool = False,
47
47
  encrypted_fs: bool = False,
48
+ passthrough_headers: dict = {},
49
+ tee: bool = False,
48
50
  **kwargs,
49
51
  ):
50
52
  from chutes.chute.cord import Cord
@@ -71,8 +73,10 @@ class Chute(FastAPI):
71
73
  self.shutdown_after_seconds = shutdown_after_seconds
72
74
  self.allow_external_egress = allow_external_egress
73
75
  self.encrypted_fs = encrypted_fs
76
+ self.passthrough_headers = passthrough_headers
74
77
  self.docs_url = None
75
78
  self.redoc_url = None
79
+ self.tee = tee
76
80
 
77
81
  @property
78
82
  def name(self):
chutes/chute/cord.py CHANGED
@@ -1,13 +1,15 @@
1
1
  import os
2
- import aiohttp
3
2
  import re
4
- import backoff
3
+ import uuid
5
4
  import gzip
6
5
  import time
7
- import orjson as json
8
- import fickling
9
6
  import pickle
10
7
  import base64
8
+ import aiohttp
9
+ import asyncio
10
+ import backoff
11
+ import fickling
12
+ import orjson as json
11
13
  from pydantic import ValidationError
12
14
  from typing import Optional, Dict, Any
13
15
  from fastapi import Request, HTTPException, status
@@ -44,6 +46,7 @@ class Cord:
44
46
  minimal_input_schema: Optional[Any] = None,
45
47
  output_content_type: Optional[str] = None,
46
48
  output_schema: Optional[Dict] = None,
49
+ sglang_passthrough: bool = False,
47
50
  **session_kwargs,
48
51
  ):
49
52
  """
@@ -67,6 +70,7 @@ class Cord:
67
70
  self._session_kwargs = session_kwargs
68
71
  self._provision_timeout = provision_timeout
69
72
  self._config = None
73
+ self._sglang_passthrough = sglang_passthrough
70
74
  self.input_models = (
71
75
  [input_schema] if input_schema and hasattr(input_schema, "__fields__") else None
72
76
  )
@@ -156,6 +160,9 @@ class Cord:
156
160
  raise InvalidPath(path)
157
161
  self._public_api_path = path
158
162
 
163
+ def _is_sglang_passthrough(self) -> bool:
164
+ return self._passthrough and self._sglang_passthrough
165
+
159
166
  @asynccontextmanager
160
167
  async def _local_call_base(self, *args, **kwargs):
161
168
  """
@@ -268,15 +275,28 @@ class Cord:
268
275
  yield data["result"]
269
276
 
270
277
  @asynccontextmanager
271
- async def _passthrough_call(self, **kwargs):
278
+ async def _passthrough_call(self, request: Request, **kwargs):
272
279
  """
273
280
  Call a passthrough endpoint.
274
281
  """
275
282
  logger.debug(
276
283
  f"Received passthrough call, passing along to {self.passthrough_path} via {self._method}"
277
284
  )
285
+ headers = kwargs.pop("headers", {}) or {}
286
+ if self._app.passthrough_headers:
287
+ headers.update(self._app.passthrough_headers)
288
+ kwargs["headers"] = headers
289
+
290
+ # Set (if needed) timeout.
291
+ timeout = None
292
+ if self._is_sglang_passthrough():
293
+ timeout = aiohttp.ClientTimeout(connect=5.0, total=None)
294
+ else:
295
+ total_timeout = kwargs.pop("timeout", 1800)
296
+ timeout = aiohttp.ClientTimeout(connect=5.0, total=total_timeout)
297
+
278
298
  async with aiohttp.ClientSession(
279
- timeout=aiohttp.ClientTimeout(connect=5.0, total=900.0),
299
+ timeout=timeout,
280
300
  read_bufsize=8 * 1024 * 1024,
281
301
  base_url=f"http://127.0.0.1:{self._passthrough_port or 8000}",
282
302
  ) as session:
@@ -300,9 +320,88 @@ class Cord:
300
320
  function=self._func.__name__,
301
321
  ).set_to_current_time()
302
322
  encrypt = getattr(request.state, "_encrypt", None)
323
+
303
324
  try:
304
325
  if self._passthrough:
305
- async with self._passthrough_call(**kwargs) as response:
326
+ rid = getattr(request.state, "sglang_rid", None)
327
+
328
+ # SGLang passthrough: run upstream call and disconnect watcher in parallel
329
+ if self._is_sglang_passthrough():
330
+
331
+ async def call_upstream():
332
+ async with self._passthrough_call(request, **kwargs) as response:
333
+ if not 200 <= response.status < 300:
334
+ try:
335
+ error_detail = await response.json()
336
+ except Exception:
337
+ error_detail = await response.text()
338
+ logger.error(
339
+ f"Failed to generate response from func={self._func.__name__}: {response.status=} -> {error_detail}"
340
+ )
341
+ raise HTTPException(
342
+ status_code=response.status,
343
+ detail=error_detail,
344
+ )
345
+ if encrypt:
346
+ raw = await response.read()
347
+ return {"json": encrypt(raw)}
348
+ return await response.json()
349
+
350
+ async def watch_disconnect():
351
+ try:
352
+ while True:
353
+ message = await request._receive()
354
+ if message.get("type") == "http.disconnect":
355
+ logger.info(
356
+ f"[{self._func.__name__}] Received http.disconnect, "
357
+ f"aborting upstream SGLang request (rid={rid})"
358
+ )
359
+ try:
360
+ await self._abort_sglang_request(rid)
361
+ except Exception as exc:
362
+ logger.warning(
363
+ f"Error while sending abort_request for rid={rid}: {exc}"
364
+ )
365
+ raise HTTPException(
366
+ status_code=499,
367
+ detail="Client disconnected during SGLang request",
368
+ )
369
+ except HTTPException:
370
+ raise
371
+ except Exception as exc:
372
+ logger.warning(f"watch_disconnect error: {exc}")
373
+ raise HTTPException(
374
+ status_code=499,
375
+ detail="Client disconnected during SGLang request",
376
+ )
377
+
378
+ upstream_task = asyncio.create_task(call_upstream())
379
+ watcher_task = asyncio.create_task(watch_disconnect())
380
+
381
+ done, pending = await asyncio.wait(
382
+ {upstream_task, watcher_task},
383
+ return_when=asyncio.FIRST_COMPLETED,
384
+ )
385
+ for task in pending:
386
+ task.cancel()
387
+
388
+ if watcher_task in done:
389
+ exc = watcher_task.exception()
390
+
391
+ if exc:
392
+ raise exc
393
+ raise HTTPException(
394
+ status_code=499,
395
+ detail="Client disconnected during SGLang request",
396
+ )
397
+ result = upstream_task.result()
398
+ logger.success(
399
+ f"Completed request [{self._func.__name__} passthrough={self._passthrough}] "
400
+ f"in {time.time() - started_at} seconds"
401
+ )
402
+ return result
403
+
404
+ async with self._passthrough_call(request, **kwargs) as response:
306
405
  if not 200 <= response.status < 300:
307
406
  try:
308
407
  error_detail = await response.json()
@@ -316,15 +415,18 @@ class Cord:
316
415
  detail=error_detail,
317
416
  )
318
417
  logger.success(
319
- f"Completed request [{self._func.__name__} passthrough={self._passthrough}] in {time.time() - started_at} seconds"
418
+ f"Completed request [{self._func.__name__} passthrough={self._passthrough}] "
419
+ f"in {time.time() - started_at} seconds"
320
420
  )
321
421
  if encrypt:
322
422
  return {"json": encrypt(await response.read())}
323
423
  return await response.json()
324
424
 
325
- response = await self._func(self._app, *args, **kwargs)
425
+ # Non-passthrough call (local Python function)
426
+ response = await asyncio.wait_for(self._func(self._app, *args, **kwargs), 1800)
326
427
  logger.success(
327
- f"Completed request [{self._func.__name__} passthrough={self._passthrough}] in {time.time() - started_at} seconds"
428
+ f"Completed request [{self._func.__name__} passthrough={self._passthrough}] "
429
+ f"in {time.time() - started_at} seconds"
328
430
  )
329
431
  if hasattr(response, "body"):
330
432
  if encrypt:
@@ -340,6 +442,19 @@ class Cord:
340
442
  if encrypt:
341
443
  return {"json": encrypt(json.dumps(response))}
342
444
  return response
445
+ except asyncio.CancelledError:
446
+ rid = getattr(request.state, "sglang_rid", None)
447
+ if self._is_sglang_passthrough():
448
+ logger.info(
449
+ f"Non-stream request for {self._func.__name__} cancelled "
450
+ f"(likely client disconnect), aborting SGLang rid={rid}"
451
+ )
452
+ try:
453
+ await self._abort_sglang_request(rid)
454
+ except Exception as exc:
455
+ logger.warning(f"Error while sending abort_request for rid={rid}: {exc}")
456
+ status = 499
457
+ raise
343
458
  except Exception as exc:
344
459
  logger.error(f"Error performing non-streamed call: {exc}")
345
460
  status = 500
@@ -371,7 +486,9 @@ class Cord:
371
486
  encrypt = getattr(request.state, "_encrypt", None)
372
487
  try:
373
488
  if self._passthrough:
374
- async with self._passthrough_call(**kwargs) as response:
489
+ rid = getattr(request.state, "sglang_rid", None)
490
+
491
+ async with self._passthrough_call(request, **kwargs) as response:
375
492
  if not 200 <= response.status < 300:
376
493
  try:
377
494
  error_detail = await response.json()
@@ -385,6 +502,13 @@ class Cord:
385
502
  detail=error_detail,
386
503
  )
387
504
  async for content in response.content:
505
+ if await request.is_disconnected():
506
+ logger.info(
507
+ f"Client disconnected for {self._func.__name__}, aborting upstream (rid={rid})"
508
+ )
509
+ await self._abort_sglang_request(rid)
510
+ break
511
+
388
512
  if encrypt:
389
513
  yield encrypt(content) + "\n"
390
514
  else:
@@ -402,6 +526,20 @@ class Cord:
402
526
  logger.success(
403
527
  f"Completed request [{self._func.__name__}] in {time.time() - started_at} seconds"
404
528
  )
529
+ except asyncio.CancelledError:
530
+ rid = getattr(request.state, "sglang_rid", None)
531
+ if self._is_sglang_passthrough():
532
+ logger.info(
533
+ f"Streaming cancelled for {self._func.__name__} "
534
+ f"(likely client disconnect), aborting SGLang rid={rid}"
535
+ )
536
+ try:
537
+ await self._abort_sglang_request(rid)
538
+ except Exception as exc:
539
+ logger.warning(f"Error while sending abort_request for rid={rid}: {exc}")
540
+ status = 499
541
+ raise
542
+
405
543
  except Exception as exc:
406
544
  logger.error(f"Error performing stream call: {exc}")
407
545
  status = 500
@@ -418,6 +556,21 @@ class Cord:
418
556
  status=status,
419
557
  ).observe(time.time() - started_at)
420
558
 
559
+ async def _abort_sglang_request(self, rid: Optional[str]):
560
+ if not rid or not self._is_sglang_passthrough():
561
+ return
562
+ try:
563
+ async with aiohttp.ClientSession(
564
+ timeout=aiohttp.ClientTimeout(connect=5.0, total=15.0),
565
+ base_url=f"http://127.0.0.1:{self._passthrough_port or 8000}",
566
+ headers=self._app.passthrough_headers or {},
567
+ ) as session:
568
+ logger.warning(f"Aborting SGLang request {rid=}")
569
+ await session.post("/abort_request", json={"rid": rid})
570
+ logger.success(f"Sent SGLang abort_request for rid={rid}")
571
+ except Exception as exc:
572
+ logger.warning(f"Failed to send abort_request for rid={rid}: {exc}")
573
+
421
574
  async def _request_handler(self, request: Request):
422
575
  """
423
576
  Decode/deserialize incoming request and call the appropriate function.
@@ -448,6 +601,13 @@ class Cord:
448
601
  else:
449
602
  args = []
450
603
  kwargs = {"json": request.state.decrypted} if request.state.decrypted else {}
604
+
605
+ # Set a custom request ID for SGLang passthroughs.
606
+ if self._is_sglang_passthrough() and isinstance(kwargs.get("json"), dict):
607
+ rid = uuid.uuid4().hex
608
+ kwargs["json"].setdefault("rid", rid)
609
+ request.state.sglang_rid = rid
610
+
451
611
  if not self._passthrough:
452
612
  if self.input_models and all([isinstance(args[idx], dict) for idx in range(len(args))]):
453
613
  try:
@@ -152,6 +152,7 @@ def build_diffusion_chute(
152
152
  max_instances: int = 1,
153
153
  scaling_threshold: float = 0.75,
154
154
  shutdown_after_seconds: int = 300,
155
+ tee: bool = False,
155
156
  ):
156
157
  chute = Chute(
157
158
  username=username,
@@ -165,6 +166,7 @@ def build_diffusion_chute(
165
166
  shutdown_after_seconds=shutdown_after_seconds,
166
167
  max_instances=max_instances,
167
168
  scaling_threshold=scaling_threshold,
169
+ tee=tee,
168
170
  )
169
171
 
170
172
  @chute.on_startup()