sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -53,7 +53,7 @@ class TemplateManager:
     def __init__(self):
         self._chat_template_name: Optional[str] = None
         self._completion_template_name: Optional[str] = None
-        self._jinja_template_content_format: Optional[str] = None
+        self._jinja_template_content_format: Optional[str] = "openai"

     @property
     def chat_template_name(self) -> Optional[str]:
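The default Jinja content format changes from unset to "openai". The content format describes the shape of `message["content"]` that a chat template is written against: a "string" template expects plain strings, while an "openai" template iterates over a list of typed parts. Roughly, the same user turn in the two shapes (hypothetical payloads, shown only for illustration):

    # "string" content format: content is a plain string
    messages_string = [{"role": "user", "content": "Describe this image."}]

    # "openai" content format: content is a list of typed parts
    messages_openai = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ]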
@@ -71,31 +71,60 @@ class TemplateManager:
         return self._jinja_template_content_format

     def load_chat_template(
-        self, tokenizer_manager, chat_template_arg: str, model_path: str
+        self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
     ) -> None:
         """
         Load a chat template from various sources.

         Args:
             tokenizer_manager: The tokenizer manager instance
-            chat_template_arg: Template name or file path
+            chat_template_arg: Template name, file path, or None to auto-detect
             model_path: Path to the model
         """
-        logger.info(f"Loading chat template: {chat_template_arg}")
+        if chat_template_arg:
+            self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
+        else:
+            # Try HuggingFace template first
+            hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+            if hf_template:
+                self._jinja_template_content_format = (
+                    detect_jinja_template_content_format(hf_template)
+                )
+                logger.info(
+                    f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                )
+                return

-        if not chat_template_exists(chat_template_arg):
-            if not os.path.exists(chat_template_arg):
-                raise RuntimeError(
-                    f"Chat template {chat_template_arg} is not a built-in template name "
-                    "or a valid chat template file path."
+            # Fallback to SGLang template guessing
+            self.guess_chat_template_from_model_path(model_path)
+
+            # Set default format if no template was found
+            if self._chat_template_name is None:
+                self._jinja_template_content_format = "string"
+                logger.info(
+                    "No chat template found, defaulting to 'string' content format"
                 )

-            if chat_template_arg.endswith(".jinja"):
-                self._load_jinja_template(tokenizer_manager, chat_template_arg)
-            else:
-                self._load_json_chat_template(chat_template_arg)
-        else:
+    def _load_explicit_chat_template(
+        self, tokenizer_manager, chat_template_arg: str
+    ) -> None:
+        """Load explicitly specified chat template."""
+        logger.info(f"Loading chat template from argument: {chat_template_arg}")
+
+        if chat_template_exists(chat_template_arg):
             self._chat_template_name = chat_template_arg
+            return
+
+        if not os.path.exists(chat_template_arg):
+            raise RuntimeError(
+                f"Chat template {chat_template_arg} is not a built-in template name "
+                "or a valid chat template file path."
+            )
+
+        if chat_template_arg.endswith(".jinja"):
+            self._load_jinja_template(tokenizer_manager, chat_template_arg)
+        else:
+            self._load_json_chat_template(chat_template_arg)

     def guess_chat_template_from_model_path(self, model_path: str) -> None:
         """
@@ -146,10 +175,7 @@ class TemplateManager:
             completion_template: Optional completion template name/path
         """
         # Load chat template
-        if chat_template:
-            self.load_chat_template(tokenizer_manager, chat_template, model_path)
-        else:
-            self.guess_chat_template_from_model_path(model_path)
+        self.load_chat_template(tokenizer_manager, chat_template, model_path)

         # Load completion template
         if completion_template:
@@ -166,7 +192,7 @@ class TemplateManager:
             chat_template
         )
         logger.info(
-            f"Detected chat template content format: {self._jinja_template_content_format}"
+            f"Detected user specified Jinja chat template with content format: {self._jinja_template_content_format}"
         )

     def _load_json_chat_template(self, template_path: str) -> None:
@@ -224,3 +250,20 @@ class TemplateManager:
             override=True,
         )
         self._completion_template_name = template["name"]
+
+    def _resolve_hf_chat_template(self, tokenizer_manager) -> Optional[str]:
+        """
+        Resolve HuggingFace chat template.
+
+        Returns the chat template string if found, None otherwise.
+        """
+        tokenizer = tokenizer_manager.tokenizer
+
+        # Try to get AutoTokenizer chat template
+        try:
+            return tokenizer.get_chat_template()
+        except Exception as e:
+            logger.debug(f"Error getting chat template via get_chat_template(): {e}")
+
+        logger.debug("No HuggingFace chat template found")
+        return None
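`get_chat_template()` is the transformers accessor that returns the tokenizer's Jinja chat template string and raises if none is configured, which is why the broad `try/except` above is what lets the manager fall through to path-based guessing. A quick way to see what a given checkpoint would resolve to (the model name is only an example):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    try:
        print(tok.get_chat_template()[:200])  # start of the Jinja template string
    except ValueError:
        print("tokenizer ships no chat template")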
@@ -170,16 +170,6 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
-    is_cross_node = server_args.dist_init_addr
-
-    if is_cross_node:
-        # Fallback to default CPU transport for multi-node
-        return "default"
-    else:
-        return "cuda_ipc"
-
-
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""

@@ -199,16 +189,6 @@ class TokenizerManager:
             else None
         )
         self.crash_dump_folder = server_args.crash_dump_folder
-        self.crash_dump_performed = False  # Flag to ensure dump is only called once
-
-        # Init inter-process communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name, True
-        )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )

         # Read model args
         self.model_path = server_args.model_path
@@ -218,8 +198,7 @@ class TokenizerManager:
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
-        self._updating = False
-        self._cond = asyncio.Condition()
+        self.max_req_input_len = None  # Will be set later in engine.py

         if self.model_config.is_multimodal:
             import_processors()
@@ -258,39 +237,57 @@ class TokenizerManager:
             revision=server_args.revision,
         )

-        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
-        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
-        # serves as the source of truth for available adapters and maps user-friendly LoRA names
-        # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+        # Init inter-process communication
+        context = zmq.asyncio.Context(2)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )

-        # Store states
+        # Request states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.asyncio_tasks = set()
+
+        # Health check
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+
+        # Dumping
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
-        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
+        self.crash_dump_request_list: deque[Tuple] = deque()
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once
+
+        # Session
         self.session_futures = {}  # session_id -> asyncio event
-        self.max_req_input_len = None
-        self.asyncio_tasks = set()

+        # Weight updates
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
+        self._is_updating = False
+        self._is_updating_cond = asyncio.Condition()

+        # LoRA
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For pd disaggregtion
+        # For PD disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -458,17 +455,11 @@ class TokenizerManager:
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
-        async with self._cond:
-            await self._cond.wait_for(lambda: not self._updating)
-
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
-            raise ValueError(
-                "This model does not appear to be an embedding model by default. "
-                "Please add `--is-embedding` when launching the server or try another model."
-            )
+        async with self._is_updating_cond:
+            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
@@ -567,6 +558,12 @@ class TokenizerManager:
                 f"model's context length ({self.context_len} tokens)."
             )

+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
+            )
+
         # Check total tokens (input + max_new_tokens)
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
@@ -959,14 +956,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self._cond:
-            self._updating = True
+        async with self._is_updating_cond:
+            self._is_updating = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self._cond:
-            self._updating = False
-            self._cond.notify_all()
+        async with self._is_updating_cond:
+            self._is_updating = False
+            self._is_updating_cond.notify_all()

     async def update_weights_from_disk(
         self,
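The rename from `_cond`/`_updating` to `_is_updating_cond`/`_is_updating` keeps the same gate pattern: `pause_generation` sets the flag and aborts in-flight requests, while new calls to `generate_request` block on `wait_for` until `continue_generation` clears the flag and notifies waiters. A standalone sketch of that pattern (illustrative names, not the SGLang classes):

    import asyncio


    class PausableGate:
        """Minimal pause/resume gate using an asyncio.Condition."""

        def __init__(self):
            self._is_updating = False
            self._is_updating_cond = asyncio.Condition()

        async def pause(self):
            async with self._is_updating_cond:
                self._is_updating = True  # new work now blocks in submit()

        async def resume(self):
            async with self._is_updating_cond:
                self._is_updating = False
                self._is_updating_cond.notify_all()  # wake all blocked submitters

        async def submit(self, work):
            async with self._is_updating_cond:
                # Park here until no update is in progress.
                await self._is_updating_cond.wait_for(lambda: not self._is_updating)
            return await work()


    async def main():
        gate = PausableGate()
        await gate.pause()
        task = asyncio.create_task(gate.submit(lambda: asyncio.sleep(0, result="done")))
        await asyncio.sleep(0.1)  # the task is parked on the condition
        await gate.resume()
        print(await task)  # -> done


    asyncio.run(main())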
@@ -1208,14 +1205,6 @@ class TokenizerManager:
         # Many DP ranks
         return [res.internal_state for res in responses]

-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     async def set_internal_state(
         self, obj: SetInternalStateReq
     ) -> SetInternalStateReqOutput:
@@ -1224,6 +1213,14 @@ class TokenizerManager:
         )
         return [res.internal_state for res in responses]

+    async def get_load(self) -> dict:
+        # TODO(lsyin): fake load report server
+        if not self.current_load_lock.locked():
+            async with self.current_load_lock:
+                internal_state = await self.get_internal_state()
+                self.current_load = internal_state[0]["load"]
+        return {"load": self.current_load}
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1343,11 +1340,24 @@ class TokenizerManager:
                 "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
             )
             return
-        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
-        self.crash_dump_performed = True
+
         if not self.crash_dump_folder:
             return

+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+
+        # Check if NFS directory is available
+        # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
+        # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
+        #     expected_nfs_dir, os.W_OK
+        # )
+        use_nfs_dir = False
+        if not use_nfs_dir:
+            logger.error(
+                f"Expected NFS directory is not available or writable. Uploading to GCS."
+            )
+
         data_to_dump = []
         if self.crash_dump_request_list:
             data_to_dump.extend(self.crash_dump_request_list)
@@ -1357,7 +1367,12 @@ class TokenizerManager:
         for rid, state in self.rid_to_state.items():
             if not state.finished:
                 unfinished_requests.append(
-                    (state.obj, {}, state.created_time, time.time())
+                    (
+                        state.obj,
+                        state.out_list[-1] if state.out_list else {},
+                        state.created_time,
+                        time.time(),
+                    )
                 )
         if unfinished_requests:
             data_to_dump.extend(unfinished_requests)
@@ -1365,10 +1380,11 @@ class TokenizerManager:
         if not data_to_dump:
             return

+        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
+            object_name,
         )

         os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1383,6 +1399,24 @@ class TokenizerManager:
             f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
         )

+        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
+            from google.cloud import storage
+
+            client = storage.Client()
+            bucket = client.bucket(bucket_name)
+            blob = bucket.blob(object_name)
+            blob.upload_from_filename(source_file_path, if_generation_match=0)
+            logger.error(
+                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
+            )
+
+        if not use_nfs_dir:
+            _upload_file_to_gcs(
+                "sglang_crash_dump",
+                filename,
+                os.getenv("HOSTNAME", None) + "/" + object_name,
+            )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
@@ -1426,7 +1460,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.perf_counter()
+            self.last_receive_tstamp = time.time()

     def _handle_batch_output(
         self,
@@ -1697,24 +1731,13 @@ class TokenizerManager:
             self.dump_requests_folder,
             datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
         )
-        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
-
-        to_dump = self.dump_request_list
+        self._dump_data_to_file(
+            data_list=self.dump_request_list,
+            filename=filename,
+            log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
+        )
         self.dump_request_list = []

-        to_dump_with_server_args = {
-            "server_args": self.server_args,
-            "requests": to_dump,
-        }
-
-        def background_task():
-            os.makedirs(self.dump_requests_folder, exist_ok=True)
-            with open(filename, "wb") as f:
-                pickle.dump(to_dump_with_server_args, f)
-
-        # Schedule the task to run in the background without awaiting it
-        asyncio.create_task(asyncio.to_thread(background_task))
-
     def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
         current_time = time.time()
         self.crash_dump_request_list.append(
@@ -1727,6 +1750,22 @@ class TokenizerManager:
         ):
             self.crash_dump_request_list.popleft()

+    def _dump_data_to_file(
+        self, data_list: List[Tuple], filename: str, log_message: str
+    ):
+        logger.info(log_message)
+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_list.copy(),
+        }
+
+        def background_task():
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            with open(filename, "wb") as f:
+                pickle.dump(to_dump_with_server_args, f)
+
+        asyncio.create_task(asyncio.to_thread(background_task))
+
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
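`_dump_data_to_file` snapshots the request list (`data_list.copy()`) and hands the blocking filesystem work to a worker thread through `asyncio.create_task(asyncio.to_thread(...))`, so the tokenizer's event loop never stalls on `pickle.dump`. The same fire-and-forget pattern in isolation (path and payload are made up):

    import asyncio
    import os
    import pickle


    def _write_blocking(payload, filename):
        # Runs in a worker thread; blocking disk I/O is fine here.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(payload, f)


    async def dump_in_background(payload, filename):
        # Scheduled without awaiting: the caller keeps serving requests.
        return asyncio.create_task(asyncio.to_thread(_write_blocking, payload, filename))


    async def main():
        task = await dump_in_background({"requests": [1, 2, 3]}, "/tmp/example_dump/reqs.pkl")
        await task  # awaited only so this demo exits after the write completes


    asyncio.run(main())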
@@ -1862,6 +1901,16 @@ class TokenizerManager:
         return scores


+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 async def print_exception_wrapper(func):
     """
     Sometimes an asyncio function does not print exception.
@@ -56,6 +56,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
@@ -66,6 +67,7 @@ class TpModelWorker:
         # Parse args
         self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.moe_ep_rank = moe_ep_rank
         self.pp_rank = pp_rank

         # Init model and tokenizer
@@ -85,6 +87,8 @@ class TpModelWorker:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            moe_ep_rank=moe_ep_rank,
+            moe_ep_size=server_args.ep_size,
             pp_rank=pp_rank,
             pp_size=server_args.pp_size,
             nccl_port=nccl_port,
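`moe_ep_rank` and `moe_ep_size` are now threaded explicitly from the scheduler through `TpModelWorker` into the model runner instead of being derived inside the MoE layers. The diff does not show how the rank is computed; one plausible mapping, assuming `--ep-size` carves the tensor-parallel group into contiguous slices, is sketched below purely as an assumption:

    def derive_moe_ep_rank(tp_rank: int, tp_size: int, ep_size: int) -> int:
        # Hypothetical mapping: EP ranks partition the TP group into ep_size slices.
        assert tp_size % ep_size == 0, "ep_size must divide tp_size"
        ranks_per_ep_group = tp_size // ep_size
        return tp_rank // ranks_per_ep_group

    # e.g. tp_size=8, ep_size=4: TP ranks 0-1 -> EP 0, 2-3 -> EP 1, 4-5 -> EP 2, 6-7 -> EP 3
    print([derive_moe_ep_rank(r, 8, 4) for r in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]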
@@ -58,13 +58,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
         self.worker = TpModelWorker(
-            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
         )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
@@ -2,7 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import Any, List, Optional

 import torch

@@ -39,7 +39,10 @@ class HiCacheStorage(ABC):

     @abstractmethod
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         """
         Retrieve the value associated with the given key.
@@ -49,7 +52,10 @@ class HiCacheStorage(ABC):

     @abstractmethod
     def batch_get(
-        self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
+        self,
+        keys: List[str],
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         """
         Retrieve values for multiple keys.
@@ -58,7 +64,13 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def set(self, key, value) -> bool:
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         """
         Store the value associated with the given key.
         Returns True if the operation was successful, False otherwise.
@@ -66,7 +78,13 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         """
         Store multiple key-value pairs.
         Returns True if all operations were successful, False otherwise.
@@ -74,7 +92,7 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def exists(self, key: str) -> bool:
+    def exists(self, key: str) -> bool | dict:
         """
         Check if the key exists in the storage.
         Returns True if the key exists, False otherwise.
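The abstract interface now threads `target_location` and `target_sizes` through every accessor so backends (including the new Mooncake, NIXL, and HF3FS stores in this release) can write straight into pre-allocated host buffers. A minimal in-memory subclass against the new signatures, shown only to illustrate what a conforming backend must implement (it is not one of the shipped backends):

    from typing import Dict

    import torch

    from sglang.srt.mem_cache.hicache_storage import HiCacheStorage


    class DictStorage(HiCacheStorage):
        """Toy backend: keeps tensors in a dict and honors target_location when given."""

        def __init__(self):
            self._store: Dict[str, torch.Tensor] = {}

        def get(self, key, target_location=None, target_sizes=None):
            value = self._store.get(key)
            if value is not None and target_location is not None:
                target_location.copy_(value)  # in-place load into the caller's buffer
                return target_location
            return value

        def batch_get(self, keys, target_locations=None, target_sizes=None):
            return [self.get(key) for key in keys]

        def set(self, key, value=None, target_location=None, target_sizes=None):
            self._store[key] = value.clone()
            return True

        def batch_set(self, keys, values=None, target_locations=None, target_sizes=None):
            return all(self.set(key, value) for key, value in zip(keys, values))

        def exists(self, key):
            return key in self._store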
@@ -85,7 +103,7 @@ class HiCacheStorage(ABC):
 class HiCacheFile(HiCacheStorage):

     def __init__(self, file_path: str = "/tmp/hicache"):
-        self.file_path = file_path
+        self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
@@ -97,25 +115,38 @@ class HiCacheFile(HiCacheStorage):
         return key + self.tp_suffix

     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # todo: fixing the target_location logic to enable in-place loading
-            loaded_tensor = torch.load(tensor_path)
-            if isinstance(loaded_tensor, torch.Tensor):
-                return loaded_tensor
+            if target_location is not None:
+                # Load directly into target_location's memory buffer
+                with open(tensor_path, "rb") as f:
+                    target_location.set_(
+                        torch.frombuffer(f.read(), dtype=target_location.dtype)
+                        .reshape(target_location.shape)
+                        .storage()
+                    )
+                return target_location
             else:
-                logger.error(f"Loaded data for key {key} is not a tensor.")
-                return None
+                loaded_tensor = torch.load(tensor_path)
+                if isinstance(loaded_tensor, torch.Tensor):
+                    return loaded_tensor
+                else:
+                    logger.error(f"Loaded data for key {key} is not a tensor.")
+                    return None
         except FileNotFoundError:
             return None

     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
             self.get(key, target_location)
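With a `target_location`, the file backend no longer allocates a fresh tensor via `torch.load`; it reads the raw bytes and rebinds the caller's tensor to that buffer with `Tensor.set_`, avoiding an extra allocation. A standalone illustration of the same in-place idea on a raw binary file, using `copy_` instead of `set_` (an alternative sketch, not the code above; paths and shapes are made up):

    import torch

    src = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    src.numpy().tofile("/tmp/demo.bin")  # raw bytes, no torch.save container

    target = torch.empty(2, 4, dtype=torch.float32)  # pre-allocated destination buffer
    with open("/tmp/demo.bin", "rb") as f:
        data = torch.frombuffer(f.read(), dtype=target.dtype).reshape(target.shape)
        target.copy_(data)  # single copy out of the read buffer

    print(torch.equal(target, src))  # True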
@@ -124,7 +155,13 @@ class HiCacheFile(HiCacheStorage):
             )
         ]

-    def set(self, key: str, value: torch.Tensor) -> bool:
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
@@ -137,7 +174,13 @@ class HiCacheFile(HiCacheStorage):
             logger.error(f"Failed to save tensor {key}: {e}")
             return False

-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         for key, value in zip(keys, values):
             if not self.set(key, value):
                 return False