sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -27,6 +27,7 @@ import threading
 import time
 import uuid
 from collections import deque
+from contextlib import nullcontext
 from datetime import datetime
 from http import HTTPStatus
 from typing import (
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -167,16 +170,6 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
-    is_cross_node = server_args.dist_init_addr
-
-    if is_cross_node:
-        # Fallback to default CPU transport for multi-node
-        return "default"
-    else:
-        return "cuda_ipc"
-
-
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""

@@ -196,16 +189,6 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
-        self.crash_dump_performed = False  # Flag to ensure dump is only called once
-
-        # Init inter-process communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name, True
-        )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )

         # Read model args
         self.model_path = server_args.model_path
@@ -215,8 +198,7 @@ class TokenizerManager:
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
-        self._updating = False
-        self._cond = asyncio.Condition()
+        self.max_req_input_len = None  # Will be set later in engine.py

         if self.model_config.is_multimodal:
             import_processors()
@@ -255,39 +237,57 @@ class TokenizerManager:
             revision=server_args.revision,
         )

-        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
-        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
-        # serves as the source of truth for available adapters and maps user-friendly LoRA names
-        # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+        # Init inter-process communication
+        context = zmq.asyncio.Context(2)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )

-        # Store states
+        # Request states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.asyncio_tasks = set()
+
+        # Health check
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+
+        # Dumping
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
-        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
+        self.crash_dump_request_list: deque[Tuple] = deque()
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once
+
+        # Session
         self.session_futures = {}  # session_id -> asyncio event
-        self.max_req_input_len = None
-        self.asyncio_tasks = set()

+        # Weight updates
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
+        self._is_updating = False
+        self._is_updating_cond = asyncio.Condition()

+        # LoRA
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For pd disaggregtion
+        # For PD disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
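
Note (illustrative only, not part of the diff): the relocated "Init inter-process communication" block above connects the tokenizer to the detokenizer and scheduler over ZeroMQ IPC sockets. `get_zmq_socket` is sglang's own helper; a rough standalone sketch of the same PULL/PUSH-over-IPC pattern with plain pyzmq, using a made-up endpoint, could look like this:

import asyncio

import zmq
import zmq.asyncio


async def main():
    ctx = zmq.asyncio.Context(2)

    # Receiver binds a PULL socket to an IPC endpoint (hypothetical path).
    recv_sock = ctx.socket(zmq.PULL)
    recv_sock.bind("ipc:///tmp/example_tokenizer_ipc")

    # Sender connects a PUSH socket to the same endpoint.
    send_sock = ctx.socket(zmq.PUSH)
    send_sock.connect("ipc:///tmp/example_tokenizer_ipc")

    await send_sock.send_pyobj({"rid": "req-0", "text": "hello"})
    print(await recv_sock.recv_pyobj())  # {'rid': 'req-0', 'text': 'hello'}


asyncio.run(main())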
@@ -455,17 +455,11 @@ class TokenizerManager:
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
-        async with self._cond:
-            await self._cond.wait_for(lambda: not self._updating)
-
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
-            raise ValueError(
-                "This model does not appear to be an embedding model by default. "
-                "Please add `--is-embedding` when launching the server or try another model."
-            )
+        async with self._is_updating_cond:
+            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
@@ -564,6 +558,12 @@ class TokenizerManager:
                 f"model's context length ({self.context_len} tokens)."
             )

+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
+            )
+
         # Check total tokens (input + max_new_tokens)
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
@@ -766,6 +766,19 @@ class TokenizerManager:
                     ):
                         raise ValueError(finish_reason["message"])

+                    if (
+                        finish_reason.get("type") == "abort"
+                        and finish_reason.get("status_code")
+                        == HTTPStatus.SERVICE_UNAVAILABLE
+                    ):
+                        # This is an abort request initiated by scheduler.
+                        # Delete the key to prevent resending abort request to the scheduler and
+                        # to ensure aborted request state is cleaned up.
+                        del self.rid_to_state[state.obj.rid]
+                        raise fastapi.HTTPException(
+                            status_code=finish_reason["status_code"],
+                            detail=finish_reason["message"],
+                        )
                 yield out
                 break

@@ -806,12 +819,21 @@ class TokenizerManager:
                     rids.append(tmp_obj.rid)
             else:
                 # Sequential tokenization and processing
-                for i in range(batch_size):
-                    tmp_obj = obj[i]
-                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                    generators.append(self._wait_one_response(tmp_obj, state, request))
-                    rids.append(tmp_obj.rid)
+                with (
+                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
+                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
+                    else nullcontext()
+                ):
+                    for i in range(batch_size):
+                        tmp_obj = obj[i]
+                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                        state = self._send_one_request(
+                            tmp_obj, tokenized_obj, created_time
+                        )
+                        generators.append(
+                            self._wait_one_response(tmp_obj, state, request)
+                        )
+                        rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
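
Note (illustrative only, not part of the diff): the `with ... if ... else nullcontext()` construct above wraps the sequential tokenization loop in `input_blocker_guard_region` only when the SGLANG_ENABLE_COLOCATED_BATCH_GEN flag is set, and otherwise uses a no-op context. A minimal sketch of that pattern, with a hypothetical `guard()` standing in for the real guard region:

import os
from contextlib import contextmanager, nullcontext


@contextmanager
def guard():
    # Stand-in for input_blocker_guard_region: block on entry, unblock on exit.
    print("block new inputs")
    try:
        yield
    finally:
        print("unblock new inputs")


def process_batch(items):
    # Pick the real guard or a no-op at runtime; the loop body stays identical.
    enabled = os.environ.get("SGLANG_ENABLE_COLOCATED_BATCH_GEN", "false") == "true"
    with guard() if enabled else nullcontext():
        for item in items:
            print("process", item)


process_batch(["a", "b"])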
@@ -934,14 +956,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self._cond:
-            self._updating = True
+        async with self._is_updating_cond:
+            self._is_updating = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self._cond:
-            self._updating = False
-            self._cond.notify_all()
+        async with self._is_updating_cond:
+            self._is_updating = False
+            self._is_updating_cond.notify_all()

     async def update_weights_from_disk(
         self,
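
Note (illustrative only, not part of the diff): `pause_generation` / `continue_generation` above implement a gate with an `asyncio.Condition`; the earlier hunk shows the request path waiting on the same condition before proceeding. A self-contained sketch of the pattern, with made-up names:

import asyncio


class Gate:
    def __init__(self):
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

    async def pause(self):
        async with self._is_updating_cond:
            self._is_updating = True

    async def resume(self):
        async with self._is_updating_cond:
            self._is_updating = False
            self._is_updating_cond.notify_all()  # wake all waiters

    async def handle_request(self, rid: str):
        # New work parks here while an update is in progress.
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
        print(f"handling {rid}")


async def main():
    gate = Gate()
    await gate.pause()
    task = asyncio.create_task(gate.handle_request("req-0"))
    await asyncio.sleep(0.1)  # request stays parked
    await gate.resume()       # releases it
    await task


asyncio.run(main())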
@@ -1183,14 +1205,6 @@ class TokenizerManager:
         # Many DP ranks
         return [res.internal_state for res in responses]

-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     async def set_internal_state(
         self, obj: SetInternalStateReq
     ) -> SetInternalStateReqOutput:
@@ -1199,6 +1213,14 @@ class TokenizerManager:
         )
         return [res.internal_state for res in responses]

+    async def get_load(self) -> dict:
+        # TODO(lsyin): fake load report server
+        if not self.current_load_lock.locked():
+            async with self.current_load_lock:
+                internal_state = await self.get_internal_state()
+                self.current_load = internal_state[0]["load"]
+        return {"load": self.current_load}
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1318,11 +1340,24 @@ class TokenizerManager:
                 "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
             )
             return
-        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
-        self.crash_dump_performed = True
+
         if not self.crash_dump_folder:
             return

+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+
+        # Check if NFS directory is available
+        # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
+        # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
+        #     expected_nfs_dir, os.W_OK
+        # )
+        use_nfs_dir = False
+        if not use_nfs_dir:
+            logger.error(
+                f"Expected NFS directory is not available or writable. Uploading to GCS."
+            )
+
         data_to_dump = []
         if self.crash_dump_request_list:
             data_to_dump.extend(self.crash_dump_request_list)
@@ -1332,7 +1367,12 @@ class TokenizerManager:
         for rid, state in self.rid_to_state.items():
             if not state.finished:
                 unfinished_requests.append(
-                    (state.obj, {}, state.created_time, time.time())
+                    (
+                        state.obj,
+                        state.out_list[-1] if state.out_list else {},
+                        state.created_time,
+                        time.time(),
+                    )
                 )
         if unfinished_requests:
             data_to_dump.extend(unfinished_requests)
@@ -1340,10 +1380,11 @@ class TokenizerManager:
         if not data_to_dump:
             return

+        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
+            object_name,
         )

         os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1358,6 +1399,24 @@ class TokenizerManager:
             f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
         )

+        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
+            from google.cloud import storage
+
+            client = storage.Client()
+            bucket = client.bucket(bucket_name)
+            blob = bucket.blob(object_name)
+            blob.upload_from_filename(source_file_path, if_generation_match=0)
+            logger.error(
+                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
+            )
+
+        if not use_nfs_dir:
+            _upload_file_to_gcs(
+                "sglang_crash_dump",
+                filename,
+                os.getenv("HOSTNAME", None) + "/" + object_name,
+            )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
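
Note (illustrative only, not part of the diff): the nested `_upload_file_to_gcs` helper above uses the standard google-cloud-storage client, and `if_generation_match=0` makes the upload succeed only when the object does not already exist. A standalone sketch of the same calls, with placeholder bucket and paths:

from google.api_core.exceptions import PreconditionFailed
from google.cloud import storage


def upload_file_to_gcs(bucket_name: str, source_file_path: str, object_name: str) -> None:
    client = storage.Client()  # uses application default credentials
    blob = client.bucket(bucket_name).blob(object_name)
    try:
        # Precondition: only create the object if no live generation exists yet.
        blob.upload_from_filename(source_file_path, if_generation_match=0)
    except PreconditionFailed:
        print(f"gs://{bucket_name}/{object_name} already exists, skipping")


# upload_file_to_gcs("my-crash-dumps", "/tmp/crash_dump.pkl", "host-0/crash_dump.pkl")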
@@ -1401,7 +1460,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.perf_counter()
+            self.last_receive_tstamp = time.time()

     def _handle_batch_output(
         self,
@@ -1672,24 +1731,13 @@ class TokenizerManager:
             self.dump_requests_folder,
             datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
         )
-        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
-
-        to_dump = self.dump_request_list
+        self._dump_data_to_file(
+            data_list=self.dump_request_list,
+            filename=filename,
+            log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
+        )
         self.dump_request_list = []

-        to_dump_with_server_args = {
-            "server_args": self.server_args,
-            "requests": to_dump,
-        }
-
-        def background_task():
-            os.makedirs(self.dump_requests_folder, exist_ok=True)
-            with open(filename, "wb") as f:
-                pickle.dump(to_dump_with_server_args, f)
-
-        # Schedule the task to run in the background without awaiting it
-        asyncio.create_task(asyncio.to_thread(background_task))
-
     def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
         current_time = time.time()
         self.crash_dump_request_list.append(
@@ -1702,11 +1750,34 @@ class TokenizerManager:
         ):
             self.crash_dump_request_list.popleft()

+    def _dump_data_to_file(
+        self, data_list: List[Tuple], filename: str, log_message: str
+    ):
+        logger.info(log_message)
+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_list.copy(),
+        }
+
+        def background_task():
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            with open(filename, "wb") as f:
+                pickle.dump(to_dump_with_server_args, f)
+
+        asyncio.create_task(asyncio.to_thread(background_task))
+
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
-        state.out_list.append(
-            {
+        if recv_obj.finished_reason:
+            out = {
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": recv_obj.finished_reason,
+                },
+            }
+        else:
+            out = {
                 "text": "",
                 "meta_info": {
                     "id": recv_obj.rid,
@@ -1718,7 +1789,7 @@ class TokenizerManager:
                     "completion_tokens": 0,
                 },
             }
-        )
+        state.out_list.append(out)
         state.event.set()

     def _handle_open_session_req_output(self, recv_obj):
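
Note (illustrative only, not part of the diff): `_dump_data_to_file` above moves the blocking pickle write off the event loop by scheduling it with `asyncio.to_thread` inside a fire-and-forget task. A minimal sketch of that offloading pattern, with made-up file name and payload:

import asyncio
import os
import pickle


def dump_to_file(filename: str, payload: dict) -> None:
    # Blocking I/O: runs in a worker thread, not on the event loop.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "wb") as f:
        pickle.dump(payload, f)


async def main():
    payload = {"server_args": {"model": "demo"}, "requests": [("rid-0", {}, 0.0, 1.0)]}
    # Fire-and-forget: schedule the threaded write without blocking the loop.
    task = asyncio.create_task(
        asyncio.to_thread(dump_to_file, "/tmp/dumps/requests.pkl", payload)
    )
    # ... the loop keeps serving other work here ...
    await task  # awaited only so this demo exits after the write finishes


asyncio.run(main())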
@@ -1830,6 +1901,16 @@ class TokenizerManager:
         return scores


+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 async def print_exception_wrapper(func):
     """
     Sometimes an asyncio function does not print exception.
@@ -1910,8 +1991,10 @@ class _Communicator(Generic[T]):
 #
 # | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
 # | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | validation    | background task | fast api            | del in _handle_abort_req    |
 # | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
 # | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | validation    | http exception  | http exception      | del in _handle_abort_req    |
 # | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
 # | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
 #
sglang/srt/managers/tp_worker.py

@@ -56,6 +56,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
@@ -66,6 +67,7 @@ class TpModelWorker:
         # Parse args
         self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.moe_ep_rank = moe_ep_rank
         self.pp_rank = pp_rank

         # Init model and tokenizer
@@ -85,6 +87,8 @@ class TpModelWorker:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            moe_ep_rank=moe_ep_rank,
+            moe_ep_size=server_args.ep_size,
             pp_rank=pp_rank,
             pp_size=server_args.pp_size,
             nccl_port=nccl_port,
@@ -130,6 +134,10 @@ class TpModelWorker:
             self.model_runner.req_to_token_pool.size,
         )
         assert self.max_running_requests > 0, "max_running_request is zero"
+        self.max_queued_requests = server_args.max_queued_requests
+        assert (
+            self.max_running_requests > 0
+        ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -165,6 +173,7 @@ class TpModelWorker:
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -58,13 +58,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
         self.worker = TpModelWorker(
-            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
         )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
sglang/srt/mem_cache/hicache_storage.py

@@ -2,7 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Any, List, Optional

 import torch

@@ -39,7 +39,10 @@ class HiCacheStorage(ABC):

     @abstractmethod
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         """
         Retrieve the value associated with the given key.
@@ -49,7 +52,10 @@ class HiCacheStorage(ABC):

     @abstractmethod
     def batch_get(
-        self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
+        self,
+        keys: List[str],
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         """
         Retrieve values for multiple keys.
@@ -58,7 +64,13 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def set(self, key, value) -> bool:
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         """
         Store the value associated with the given key.
         Returns True if the operation was successful, False otherwise.
@@ -66,7 +78,13 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         """
         Store multiple key-value pairs.
         Returns True if all operations were successful, False otherwise.
@@ -74,7 +92,7 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def exists(self, key: str) -> bool:
+    def exists(self, key: str) -> bool | dict:
         """
         Check if the key exists in the storage.
         Returns True if the key exists, False otherwise.
@@ -85,7 +103,7 @@ class HiCacheStorage(ABC):
 class HiCacheFile(HiCacheStorage):

     def __init__(self, file_path: str = "/tmp/hicache"):
-        self.file_path = file_path
+        self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
@@ -97,7 +115,10 @@ class HiCacheFile(HiCacheStorage):
         return key + self.tp_suffix

     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
@@ -115,7 +136,8 @@ class HiCacheFile(HiCacheStorage):
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
             self.get(key, target_location)
@@ -124,7 +146,13 @@ class HiCacheFile(HiCacheStorage):
             )
         ]

-    def set(self, key: str, value: torch.Tensor) -> bool:
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
@@ -137,7 +165,13 @@ class HiCacheFile(HiCacheStorage):
             logger.error(f"Failed to save tensor {key}: {e}")
             return False

-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         for key, value in zip(keys, values):
             if not self.set(key, value):
                 return False
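
Note (illustrative only, not part of the diff): the widened HiCacheStorage signatures above let a backend receive caller-provided target locations and sizes (e.g. preallocated buffers) instead of only torch tensors. A toy in-memory backend written against the same method shapes, purely as a sketch and not a real sglang backend:

from typing import Any, Dict, List, Optional

import torch


class InMemoryCache:
    # Mirrors the HiCacheStorage method shapes shown above.

    def __init__(self):
        self._store: Dict[str, torch.Tensor] = {}

    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        value = self._store.get(key)
        if value is not None and target_location is not None:
            target_location.copy_(value)  # write into the caller-provided buffer
        return value

    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
        locations = target_locations or [None] * len(keys)
        return [self.get(key, loc) for key, loc in zip(keys, locations)]

    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        self._store[key] = value.clone()
        return True

    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        return all(self.set(key, value) for key, value in zip(keys, values))

    def exists(self, key: str) -> bool:
        return key in self._store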