sglang-0.4.7-py3-none-any.whl → sglang-0.4.7.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -35,11 +35,17 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
+from sglang.srt.utils import (
+    empty_context,
+    get_available_gpu_memory,
+    is_cuda,
+    next_power_of_2,
+)
 
 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -152,6 +158,12 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()
 
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
+
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
@@ -254,7 +266,7 @@
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
         # Capture extend
@@ -269,7 +281,7 @@
         )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     @property
@@ -365,14 +377,21 @@
         )
 
         # Allocate cache locations
+        # Layout of the out_cache_loc
+        # [       topk 0        ] [       topk 1        ]
+        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
             )
         else:
             if self.topk == 1:
-                prefix_lens = batch.seq_lens
-                seq_lens = prefix_lens + self.speculative_num_steps
+                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                )
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -385,29 +404,33 @@
                 # "x" means speculative draft tokens
                 # "." means padded tokens
 
-                # TODO: fuse these ops
-                prefix_lens = batch.seq_lens
-                last_page_lens = prefix_lens % self.page_size
-                num_new_pages = (
-                    last_page_lens + self.speculative_num_steps + self.page_size - 1
-                ) // self.page_size
-                seq_lens = (
-                    prefix_lens // self.page_size * self.page_size
-                    + num_new_pages * (self.page_size * self.topk)
-                )
-                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
-                raise NotImplementedError(
-                    "page_size > 1 and top_k > 1 are not supported."
+                # TODO(lmzheng): The current implementation is still a fake support
+                # for page size > 1. In the `assign_draft_cache_locs` below,
+                # we directly move the indices instead of the real kv cache.
+                # This only works when the kernel backend runs with page size = 1.
+                # If the kernel backend runs with page size > 1, we need to
+                # duplicate the real KV cache. The overhead of duplicating KV
+                # cache seems okay because the draft KV cache only has one layer.
+                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
+
+                (
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    self.num_new_pages_per_topk,
+                    self.extend_lens,
+                ) = get_last_loc_large_page_size_large_top_k(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                    self.topk,
+                    self.page_size,
                 )
-                # TODO: Support page_size > 1 and top_k > 1
-                # 1. Duplicate the KV cache in the last partial page for all top-k segments
-                # 2. Modify generate_draft_decode_kv_indices accordingly
-
-            last_loc = get_last_loc(
-                batch.req_to_token_pool.req_to_token,
-                batch.req_pool_indices,
-                prefix_lens,
-            )
+
+                # TODO(lmzheng): remove this device sync
+                extend_num_tokens = torch.sum(self.extend_lens).item()
+
             out_cache_loc, token_to_kv_pool_state_backup = (
                 batch.alloc_paged_token_slots_extend(
                     prefix_lens,
@@ -422,18 +445,30 @@
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
+            self.extend_lens,
+            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
         )
+
+        if self.page_size > 1 and self.topk > 1:
+            # Remove padded slots
+            out_cache_loc = out_cache_loc[
+                : num_seqs * self.topk * self.speculative_num_steps
+            ]
+
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
 
         # Get forward batch
-        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -448,9 +483,6 @@
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-            forward_batch = ForwardBatch.init_new(
-                model_worker_batch, self.draft_model_runner
-            )
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
@@ -503,6 +535,13 @@
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]
 
+        out_cache_loc = out_cache_loc.reshape(
+            forward_batch.batch_size, self.topk, self.speculative_num_steps
+        )
+        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
+            self.speculative_num_steps, -1
+        )
+
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
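Note: the reshape/permute added above assumes the allocation layout documented in the earlier layout comment (for each request, `topk` groups of `speculative_num_steps` slots). A minimal sketch with made-up sizes, showing how the permute yields one row of cache slots per draft step:

    import torch

    # Illustrative sizes only (not taken from the diff).
    batch_size, topk, num_steps = 2, 2, 3

    # Allocation order: [request][topk][draft step], flattened.
    out_cache_loc = torch.arange(batch_size * topk * num_steps)

    # Same reshape/permute as in draft_forward: row i holds the slots
    # used by every (request, topk) pair at draft step i.
    per_step = (
        out_cache_loc.reshape(batch_size, topk, num_steps)
        .permute(2, 0, 1)
        .reshape(num_steps, -1)
    )
    # per_step[0] == tensor([0, 3, 6, 9]); per_step[1] == tensor([1, 4, 7, 10])

This is why the next hunk can simply index `out_cache_loc[i]` for step i instead of slicing a per-request view.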
@@ -524,10 +563,7 @@
 
             # Set inputs
             forward_batch.input_ids = input_ids
-            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
-            forward_batch.out_cache_loc = out_cache_loc[
-                :, self.topk * i : self.topk * (i + 1)
-            ].flatten()
+            forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -547,11 +583,13 @@
 
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
+        batch.return_hidden_states = False
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -583,7 +621,7 @@
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # otherwise, this vocab mask will be the one from the previous extend stage
+            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
 
@@ -604,13 +642,13 @@
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
 
-        if batch.return_logprob:
-            self.add_logprob_values(batch, res, logits_output)
-
         return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
@@ -623,8 +661,16 @@
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
+        accepted_indices = res.accepted_indices
+        assert len(accepted_indices) == len(logits_output.next_token_logits)
+        temperatures = batch.sampling_info.temperatures
+        num_draft_tokens = batch.spec_info.draft_token_num
+        # acceptance indices are the indices in a "flattened" batch.
+        # dividing it to num_draft_tokens will yield the actual batch index.
+        temperatures = temperatures[accepted_indices // num_draft_tokens]
+
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits, dim=-1
+            logits_output.next_token_logits / temperatures, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
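Note: `accepted_indices` are positions in a flattened (batch × draft_token_num) layout, so integer division by the number of draft tokens recovers the owning request, letting each accepted token be scored with that request's temperature. A small sketch with made-up values (temperature shape assumed to be [batch, 1]):

    import torch

    num_draft_tokens = 4                          # draft tokens per request
    temperatures = torch.tensor([[0.7], [1.0]])   # one temperature per request
    accepted_indices = torch.tensor([0, 1, 5])    # flattened indices of accepted tokens

    req_idx = accepted_indices // num_draft_tokens   # tensor([0, 0, 1])
    per_token_temp = temperatures[req_idx]           # shape [3, 1]

    logits = torch.randn(3, 32000)
    logprobs = torch.nn.functional.log_softmax(logits / per_token_temp, dim=-1)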
@@ -659,7 +705,7 @@
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -691,6 +737,7 @@
             hidden_states=hidden_states,
             verified_id=next_token_ids,
         )
+        batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch(
@@ -781,4 +828,48 @@ def load_token_map(token_map_path: str) -> List[int]:
         )
         token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.int32)
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_top_k_1(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens,
+    speculative_num_steps: int,
+):
+    prefix_lens = seq_lens
+    seq_lens = prefix_lens + speculative_num_steps
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return prefix_lens, seq_lens, last_loc
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_large_top_k(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens: torch.Tensor,
+    speculative_num_steps: int,
+    topk: int,
+    page_size: int,
+):
+    prefix_lens = seq_lens
+    last_page_lens = prefix_lens % page_size
+    num_new_pages_per_topk = (
+        last_page_lens + speculative_num_steps + page_size - 1
+    ) // page_size
+    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
+        page_size * topk
+    )
+    extend_lens = seq_lens - prefix_lens
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+
+    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
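Note: a quick worked example of the page arithmetic in `get_last_loc_large_page_size_large_top_k`, using made-up numbers (page_size=4, 3 draft steps, topk=2, one request with a 10-token prefix):

    page_size, speculative_num_steps, topk = 4, 3, 2
    prefix_len = 10

    last_page_len = prefix_len % page_size            # 2 tokens occupy the last partial page
    num_new_pages_per_topk = (
        last_page_len + speculative_num_steps + page_size - 1
    ) // page_size                                    # 2 new pages per top-k branch
    new_seq_len = (
        prefix_len // page_size * page_size
        + num_new_pages_per_topk * page_size * topk
    )                                                 # 8 + 16 = 24
    extend_len = new_seq_len - prefix_len             # 14 slots allocated for this request

Only num_seqs * topk * speculative_num_steps slots (6 here) hold real draft tokens, which is presumably why the earlier hunk trims `out_cache_loc` back to that length when page_size > 1 and topk > 1.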
sglang/srt/two_batch_overlap.py CHANGED
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
sglang/srt/utils.py CHANGED
@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import functools
 import importlib
 import io
 import ipaddress
@@ -837,6 +838,7 @@ class CustomCacheManager(FileCacheManager):
 
 
 def set_ulimit(target_soft_limit=65535):
+    # number of open files
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
 
@@ -846,6 +848,18 @@
         except ValueError as e:
             logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
 
+    # stack size
+    resource_type = resource.RLIMIT_STACK
+    current_soft, current_hard = resource.getrlimit(resource_type)
+    target_soft_limit_stack_size = 1024 * target_soft_limit
+    if current_soft < target_soft_limit_stack_size:
+        try:
+            resource.setrlimit(
+                resource_type, (target_soft_limit_stack_size, current_hard)
+            )
+        except ValueError as e:
+            logger.warning(f"Fail to set RLIMIT_STACK: {e}")
+
 
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
@@ -1373,6 +1387,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)
 
 
+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
@@ -2197,6 +2216,45 @@ class Withable(Generic[T]):
         self._value = None
 
 
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
+
+
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
 
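Note: a small usage sketch of the new `merge_bias_tensor` helper (values made up; assumes this post1 wheel is installed). When one side of a batch merge has no bias tensor, it is filled with the default before concatenation:

    import torch

    from sglang.srt.utils import merge_bias_tensor

    lhs = None                                   # first batch (2 requests) carries no bias
    rhs = torch.tensor([[0.5], [1.5], [2.5]])    # second batch (3 requests) does

    merged = merge_bias_tensor(lhs, rhs, bs1=2, bs2=3, device="cpu", default=0.0)
    # merged.shape == torch.Size([5, 1]); the first two rows are filled with 0.0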
sglang/test/attention/test_prefix_chunk_info.py CHANGED
@@ -2,6 +2,8 @@ import unittest
 
 import torch
 
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.test.test_utils import CustomTestCase
sglang/test/runners.py CHANGED
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]
 
 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@
                 self.model = _get_sentence_transformer_embedding_model(
                     model_path, torch_dtype
                 )
-        elif self.model_type == "reward":
+        elif self.model_type == "reward" or self.model_type == "cross_encoder":
             from transformers import AutoModelForSequenceClassification
 
             self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@
                     else:
                         logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
+                elif self.model_type == "cross_encoder":
+                    inputs = self.tokenizer(
+                        prompts, padding=True, return_tensors="pt"
+                    ).to("cuda")
+                    scores = self.model(**inputs).logits
+                    scores = scores.squeeze().tolist()
+                    if not isinstance(scores, list):
+                        scores = [scores]
+                    out_queue.put(ModelOutput(scores=scores))
 
                 elif self.model_type == "reward":
                     scores = []
@@ -322,7 +346,9 @@
 
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@
 
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@
                 else:
                     logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
+            # cross encoder model
+            elif self.model_type == "cross_encoder":
+                response = self.engine.rerank(prompts)
+                if not isinstance(response, list):
+                    response = [response]
+                scores = [x["embedding"] for x in response]
+                return ModelOutput(scores=scores)
             # reward model
             else:
                 response = self.engine.encode(prompts)
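Note: the cross-encoder path above scores (query, document) pairs with a sequence-classification head. A minimal Transformers-only sketch of the same idea; the checkpoint name is an assumption for illustration, not something named in this diff:

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "BAAI/bge-reranker-base"  # assumed example reranker checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    pairs = [
        ["How many people live in Berlin?",
         "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."],
        ["How many people live in Berlin?", "Berlin is well known for its museums."],
    ]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
        scores = model(**inputs).logits.squeeze(-1)  # higher score = more relevant document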
sglang/test/test_block_fp8.py CHANGED
@@ -343,6 +343,7 @@ class TestW8A8BlockFP8Matmul(CustomTestCase):
     OUT_DTYPES = [torch.bfloat16]
     M = [64, 128, 512, 1024, 4096]
     NKs = [
+        (2112, 7168),
         (1536, 7168),
         (3072, 1536),
         (24576, 7168),