sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import concurrent.futures
 import logging
 import math
 import threading
@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
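
The hunk above adds an io_backend knob to HiCacheController and, when it is left empty, derives a default from the page size. A minimal sketch of that selection rule, with the 64-token threshold copied from the diff (choose_io_backend is a hypothetical helper name, not an sglang API):

# Sketch only: mirrors the default selection added above; not part of sglang.
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64  # pages of >= 64 tokens use DMA-style "direct" copies

def choose_io_backend(page_size: int, io_backend: str = "") -> str:
    if io_backend:  # an explicit setting always wins
        return io_backend
    # small pages favor the gather/scatter kernel; large pages favor direct slicing
    return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"

assert choose_io_backend(1) == "kernel"
assert choose_io_backend(64) == "direct"
assert choose_io_backend(128, io_backend="kernel") == "kernel"
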
@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()
 
         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()
 
         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
        )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
         )
         return device_indices
 
+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
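
The new move_indices helper encodes where the index tensors must live for each backend: the kernel path launches a GPU gather/scatter, so the host indices are copied to the device, while the direct path slices host tensors and therefore needs the device indices on the CPU. A hedged sketch of the same rule as a standalone function (place_indices is an illustrative name, not sglang's API):

import torch

def place_indices(host_indices: torch.Tensor, device_indices: torch.Tensor,
                  io_backend: str, device: str = "cuda"):
    # Sketch mirroring HiCacheController.move_indices above; not sglang code.
    if io_backend == "kernel":
        # the transfer kernel runs on the GPU, so host-side indices move to the device
        return host_indices.to(device), device_indices
    if io_backend == "direct":
        # direct indexing operates on host tensors, so device-side indices move to the CPU
        return host_indices, device_indices.cpu()
    raise ValueError(f"Unsupported io backend: {io_backend}")
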
@@ -289,10 +298,14 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.mem_pool_host.write_page_all_layers(
-                    operation.host_indices,
-                    operation.device_indices,
-                    self.mem_pool_device,
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)
 
-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@ class HiCacheController:
 
             # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
+            host_indices, device_indices = self.move_indices(
+                batch_operation.host_indices, batch_operation.device_indices
+            )
             for i in range(self.mem_pool_host.layer_num):
-                if self.page_size == 1:
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
-                else:
-                    self.mem_pool_host.load_page_per_layer(
-                        batch_operation.host_indices,
-                        batch_operation.device_indices,
-                        self.mem_pool_device,
-                        i,
-                    )
-                self.load_stream.synchronize()
+                self.mem_pool_device.load_from_host_per_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    i,
+                    self.io_backend,
+                )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()
 
             self.mem_pool_host.complete_io(batch_operation.host_indices)
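
With this change the load thread converts the indices once, then streams the cache in layer order: each iteration calls load_from_host_per_layer, synchronizes the load stream, and increments the layer-done counter so the forward pass can consume layer i as soon as it arrives. The sketch below shows only the handshake pattern with a toy counter; it is not sglang's LayerDoneCounter.

import threading

class ToyLayerDoneCounter:
    # Illustrative stand-in for the layer-done handshake; not sglang's implementation.
    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        self.done = 0
        self.cond = threading.Condition()

    def reset(self):
        with self.cond:
            self.done = 0

    def increment(self):
        with self.cond:
            self.done += 1
            self.cond.notify_all()

    def wait_until(self, layer_id: int):
        with self.cond:
            self.cond.wait_for(lambda: self.done > layer_id)

counter = ToyLayerDoneCounter(num_layers=4)

def loader():
    for i in range(counter.num_layers):
        # stand-in for load_from_host_per_layer(...) followed by load_stream.synchronize()
        counter.increment()

threading.Thread(target=loader, daemon=True).start()
counter.wait_until(2)  # the consumer of layer 2 unblocks once three layers have landed
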
@@ -372,148 +360,6 @@ class HiCacheController:
                 if node_id != 0:
                     self.ack_load_queue.put(node_id)
 
-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
sglang/srt/managers/io_struct.py

@@ -200,6 +200,8 @@ class GenerateReqInput:
             self.text = [self.text]
         if self.input_ids is not None:
             self.input_ids = [self.input_ids]
+        if self.input_embeds is not None:
+            self.input_embeds = [self.input_embeds]
 
     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ class GenerateReqInput:
             new_rids = [f"{self.rid}_{i}" for i in range(num)]
             self.rid = new_rids
         elif isinstance(self.rid, list):
-            if len(self.rid) != num:
+            # Note: the length of rid shall be the same as the batch_size,
+            # as the rid would be expanded for parallel sampling in tokenizer_manager
+            if len(self.rid) != self.batch_size:
                 raise ValueError(
                     "The specified rids length mismatch with the batch_size for batch processing."
                 )
@@ -400,6 +404,9 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
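
Together, these GenerateReqInput hunks make input_embeds behave like text and input_ids: a single request is wrapped into a one-element list during batch normalization, and slicing a batched request forwards the i-th embedding (or None when the field was never set). A toy dataclass with the same wrap-and-slice pattern (names are illustrative, not the real API):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ToyGenerateReq:
    # Illustrative subset of GenerateReqInput; not the real class.
    text: Optional[Any] = None
    input_embeds: Optional[Any] = None

    def normalize_single_to_batch(self):
        # called only when the request holds a single example, mirroring the wrapping above
        if self.text is not None:
            self.text = [self.text]
        if self.input_embeds is not None:
            self.input_embeds = [self.input_embeds]

    def slice(self, i: int) -> "ToyGenerateReq":
        # each sub-request keeps its own embedding, or None if the field was unset
        return ToyGenerateReq(
            text=self.text[i] if self.text is not None else None,
            input_embeds=self.input_embeds[i] if self.input_embeds is not None else None,
        )
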
sglang/srt/managers/mm_utils.py

@@ -248,7 +248,9 @@ def _get_chunked_prefill_embedding(
 ) -> Optional[torch.Tensor]:
     # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
-    for i in range(len(items_size) - 1):
+    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
+    max_iterations = min(len(items_size) - 1, len(prefix_length))
+    for i in range(max_iterations):
         if items_size[i] == items_size[i + 1]:
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
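
The workaround above simply clamps the loop bound so a longer items_size list (observed with eagle3) can never index past prefix_length. The guard in isolation, with hypothetical toy inputs:

def safe_iteration_count(items_size, prefix_length):
    # items_size holds one more boundary than the number of requests;
    # never iterate past the metadata that actually exists
    return min(len(items_size) - 1, len(prefix_length))

assert safe_iteration_count([0, 2, 5], [1, 1]) == 2     # lengths agree
assert safe_iteration_count([0, 2, 5, 7], [1, 1]) == 2  # eagle3-style mismatch is clamped
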
@@ -269,7 +271,7 @@ def _get_chunked_prefill_embedding(
         embedding_per_req_chunk, _, end_index = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
-            extend_seq_len=extend_length[i],
+            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
         # remove this item from cache if chunk reaches to the end
sglang/srt/managers/schedule_batch.py

@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
+    "enable_triton_kernel_moe",
 ]
 
 # Put some global args for easy access
@@ -842,7 +843,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens_for_logprob: Optional[List[int]] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
-    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
 
sglang/srt/managers/scheduler.py

@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import datetime
 import faulthandler
 import logging
 import os
@@ -590,6 +591,12 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
+                hicache_io_backend=(
+                    "direct"
+                    if server_args.attention_backend
+                    == "fa3"  # hot fix for incompatibility
+                    else server_args.hicache_io_backend
+                ),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
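
This scheduler hunk hard-wires the direct backend whenever FlashAttention-3 is the attention backend, overriding whatever --hicache-io-backend was configured; the in-diff comment labels it a hot fix for an incompatibility. The same rule as a standalone sketch (the helper name is hypothetical):

def effective_hicache_io_backend(attention_backend: str, hicache_io_backend: str) -> str:
    # hot fix mirrored from the diff: the kernel transfer path is incompatible with fa3,
    # so fall back to direct copies regardless of the configured io backend
    if attention_backend == "fa3":
        return "direct"
    return hicache_io_backend

assert effective_hicache_io_backend("fa3", "kernel") == "direct"
assert effective_hicache_io_backend("flashinfer", "kernel") == "kernel"
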
@@ -1313,10 +1320,12 @@ class Scheduler(
             f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
             f += f"#queue-req: {len(self.waiting_queue)}, "
             f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
+            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
         else:
             f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}"
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+
+        f += f"timestamp: {datetime.datetime.now().isoformat()}"
 
         logger.info(f)
 
@@ -1378,7 +1387,8 @@ class Scheduler(
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
+            f"#queue-req: {len(self.waiting_queue)}, "
+            f"timestamp: {datetime.datetime.now().isoformat()}"
         )
 
         logger.info(msg)
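
Both the prefill and decode log lines now end with an ISO-8601 timestamp from datetime.datetime.now().isoformat(), which makes throughput samples easier to correlate across ranks or with external monitoring. A minimal sketch of the resulting message shape (values are made up):

import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("toy_scheduler")

gen_throughput, queue_len = 1234.56, 3  # made-up sample values
msg = (
    f"gen throughput (token/s): {gen_throughput:.2f}, "
    f"#queue-req: {queue_len}, "
    f"timestamp: {datetime.datetime.now().isoformat()}"
)
logger.info(msg)
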
@@ -2333,9 +2343,8 @@ class Scheduler(
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
         tags = recv_req.tags
-        import subprocess
 
-        if tags is None:
+        if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
@@ -2346,17 +2355,20 @@ class Scheduler(
             self.stashed_model_static_state = _export_static_state(
                 self.tp_worker.worker.model_runner.model
             )
+            torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
 
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
         tags = recv_req.tags
+
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
                 self.tp_worker.worker.model_runner.model,
                 self.stashed_model_static_state,
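
The added barriers make every rank in the TP CPU group reach the same point before weight memory is paused and immediately after it is resumed, so no rank releases or repopulates weights while a peer is still exporting or importing static state. Below is a hedged sketch of that ordering; the single-process gloo group, the stand-in export/import helpers, and the pause/resume comments are illustrative, while the real code reuses the scheduler's existing tp_cpu_group and memory_saver_adapter.

import os
import torch.distributed as dist

# single-process gloo group just to make the sketch runnable
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
tp_cpu_group = dist.group.WORLD

def export_static_state(model):  # hypothetical stand-in for the helper used in the diff
    return {"stashed": True}

def import_static_state(model, stashed):  # hypothetical stand-in
    pass

def release_weights(model):
    stashed = export_static_state(model)
    dist.barrier(group=tp_cpu_group)  # every rank finishes exporting before memory is paused
    # memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS) would run here
    return stashed

def resume_weights(model, stashed):
    # memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) would run here
    dist.barrier(group=tp_cpu_group)  # every rank re-allocates weights before importing state
    import_static_state(model, stashed)

stashed = release_weights(model=None)
resume_weights(model=None, stashed=stashed)
dist.destroy_process_group()
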
sglang/srt/mem_cache/hiradix_cache.py

@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
         hicache_ratio: float,
         hicache_size: int,
         hicache_write_policy: str,
+        hicache_io_backend: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
             page_size,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
+            io_backend=hicache_io_backend,
         )
 
         # record the nodes with ongoing write through