sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -35,11 +35,17 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    empty_context,
+    get_available_gpu_memory,
+    is_cuda,
+    next_power_of_2,
+)

 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -152,6 +158,12 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()

+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
+
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
@@ -254,7 +266,7 @@ class EAGLEWorker(TpModelWorker):
             self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
             after_mem = get_available_gpu_memory(self.device, self.gpu_id)
             logger.info(
-                f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s.
+                f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
             )

         # Capture extend
@@ -269,7 +281,7 @@ class EAGLEWorker(TpModelWorker):
             )
             after_mem = get_available_gpu_memory(self.device, self.gpu_id)
             logger.info(
-                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s.
+                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
             )

     @property
@@ -365,14 +377,21 @@ class EAGLEWorker(TpModelWorker):
         )

         # Allocate cache locations
+        # Layout of the out_cache_loc
+        # [ topk 0 ] [ topk 1 ]
+        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.
+                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
             )
         else:
             if self.topk == 1:
-                prefix_lens =
-
+                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                )
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -385,29 +404,33 @@ class EAGLEWorker(TpModelWorker):
                 # "x" means speculative draft tokens
                 # "." means padded tokens

-                # TODO:
-
-
-
-
-
-
-
-
-
-
-
+                # TODO(lmzheng): The current implementation is still a fake support
+                # for page size > 1. In the `assign_draft_cache_locs` below,
+                # we directly move the indices instead of the real kv cache.
+                # This only works when the kernel backend runs with page size = 1.
+                # If the kernel backend runs with page size > 1, we need to
+                # duplicate the real KV cache. The overhead of duplicating KV
+                # cache seems okay because the draft KV cache only has one layer.
+                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
+
+                (
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    self.num_new_pages_per_topk,
+                    self.extend_lens,
+                ) = get_last_loc_large_page_size_large_top_k(
+                    batch.req_to_token_pool.req_to_token,
+                    batch.req_pool_indices,
+                    batch.seq_lens,
+                    self.speculative_num_steps,
+                    self.topk,
+                    self.page_size,
                 )
-
-                #
-
-
-            last_loc = get_last_loc(
-                batch.req_to_token_pool.req_to_token,
-                batch.req_pool_indices,
-                prefix_lens,
-            )
+
+                # TODO(lmzheng): remove this device sync
+                extend_num_tokens = torch.sum(self.extend_lens).item()
+
             out_cache_loc, token_to_kv_pool_state_backup = (
                 batch.alloc_paged_token_slots_extend(
                     prefix_lens,
@@ -422,18 +445,30 @@ class EAGLEWorker(TpModelWorker):
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
+            self.extend_lens,
+            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
         )
+
+        if self.page_size > 1 and self.topk > 1:
+            # Remove padded slots
+            out_cache_loc = out_cache_loc[
+                : num_seqs * self.topk * self.speculative_num_steps
+            ]
+
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST

         # Get forward batch
-        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -448,9 +483,6 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-            forward_batch = ForwardBatch.init_new(
-                model_worker_batch, self.draft_model_runner
-            )
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)

@@ -503,6 +535,13 @@ class EAGLEWorker(TpModelWorker):
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]

+        out_cache_loc = out_cache_loc.reshape(
+            forward_batch.batch_size, self.topk, self.speculative_num_steps
+        )
+        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
+            self.speculative_num_steps, -1
+        )
+
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
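The reshape/permute added above regroups the flat cache-slot allocation so that `out_cache_loc[i]` selects the slots written by draft step `i` across all requests and topk branches. A standalone sketch of the same index shuffle, with made-up sizes (not taken from the diff):

```python
import torch

# Hypothetical sizes: 2 requests, topk=3 branches, 4 speculative steps.
bs, topk, steps = 2, 3, 4
out_cache_loc = torch.arange(bs * topk * steps)  # flat allocation

# [bs * topk * steps] -> [steps, bs * topk]: row i holds the cache locations
# used by draft forward step i for every (request, branch) pair.
per_step = out_cache_loc.reshape(bs, topk, steps).permute(2, 0, 1).reshape(steps, -1)
assert per_step.shape == (steps, bs * topk)
```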
@@ -524,10 +563,7 @@ class EAGLEWorker(TpModelWorker):

             # Set inputs
             forward_batch.input_ids = input_ids
-            out_cache_loc = out_cache_loc
-            forward_batch.out_cache_loc = out_cache_loc[
-                :, self.topk * i : self.topk * (i + 1)
-            ].flatten()
+            forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -547,11 +583,13 @@ class EAGLEWorker(TpModelWorker):

     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
+        batch.return_hidden_states = False
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

         if batch.has_grammar:
             retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -583,7 +621,7 @@ class EAGLEWorker(TpModelWorker):
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # otherwise, this vocab mask will be the one from the previous extend stage
+            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None

@@ -604,13 +642,13 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

-        if batch.return_logprob:
-            self.add_logprob_values(batch, res, logits_output)
-
         return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
@@ -623,8 +661,16 @@ class EAGLEWorker(TpModelWorker):
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
+        accepted_indices = res.accepted_indices
+        assert len(accepted_indices) == len(logits_output.next_token_logits)
+        temperatures = batch.sampling_info.temperatures
+        num_draft_tokens = batch.spec_info.draft_token_num
+        # acceptance indices are the indices in a "flattened" batch.
+        # dividing it to num_draft_tokens will yield the actual batch index.
+        temperatures = temperatures[accepted_indices // num_draft_tokens]
+
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits, dim=-1
+            logits_output.next_token_logits / temperatures, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
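The `accepted_indices // num_draft_tokens` trick above maps each surviving row of the flattened verify batch back to its request, so every logit row is divided by the right sampling temperature before `log_softmax`. A standalone sketch with made-up shapes (not the worker's real tensors):

```python
import torch

num_draft_tokens = 4                                # draft tokens per request
temperatures = torch.tensor([[0.7], [1.0]])         # one temperature per request (2 requests)
accepted_indices = torch.tensor([0, 1, 4, 5, 6])    # rows kept after verification

# Row r of the flattened batch belongs to request r // num_draft_tokens.
per_row_temperature = temperatures[accepted_indices // num_draft_tokens]  # shape [5, 1]

logits = torch.randn(len(accepted_indices), 32)     # fake vocabulary of 32 tokens
logprobs = torch.nn.functional.log_softmax(logits / per_row_temperature, dim=-1)
assert logprobs.shape == (5, 32)
```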
@@ -659,7 +705,7 @@ class EAGLEWorker(TpModelWorker):
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
@@ -691,6 +737,7 @@ class EAGLEWorker(TpModelWorker):
             hidden_states=hidden_states,
             verified_id=next_token_ids,
         )
+        batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch(
@@ -781,4 +828,48 @@ def load_token_map(token_map_path: str) -> List[int]:
     )
     token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_top_k_1(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens,
+    speculative_num_steps: int,
+):
+    prefix_lens = seq_lens
+    seq_lens = prefix_lens + speculative_num_steps
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+    return prefix_lens, seq_lens, last_loc
+
+
+@torch.compile(dynamic=True)
+def get_last_loc_large_page_size_large_top_k(
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens: torch.Tensor,
+    speculative_num_steps: int,
+    topk: int,
+    page_size: int,
+):
+    prefix_lens = seq_lens
+    last_page_lens = prefix_lens % page_size
+    num_new_pages_per_topk = (
+        last_page_lens + speculative_num_steps + page_size - 1
+    ) // page_size
+    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
+        page_size * topk
+    )
+    extend_lens = seq_lens - prefix_lens
+    last_loc = get_last_loc(
+        req_to_token,
+        req_pool_indices,
+        prefix_lens,
+    )
+
+    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
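To make the paging arithmetic in `get_last_loc_large_page_size_large_top_k` concrete, here is a small worked example with illustrative numbers (not taken from the diff); each topk branch is padded out to whole pages, which is where the extra `extend_lens` tokens come from:

```python
# Illustrative numbers only: page_size=4, 3 speculative steps, topk=2,
# and one request whose prefix already holds 10 tokens.
page_size, speculative_num_steps, topk = 4, 3, 2
prefix_len = 10

last_page_len = prefix_len % page_size              # 2 tokens sit on a partial page
num_new_pages_per_topk = (
    last_page_len + speculative_num_steps + page_size - 1
) // page_size                                       # ceil((2 + 3) / 4) = 2 pages per branch
seq_len = prefix_len // page_size * page_size + num_new_pages_per_topk * page_size * topk
extend_len = seq_len - prefix_len

print(num_new_pages_per_topk, seq_len, extend_len)   # 2 24 14
```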
sglang/srt/two_batch_overlap.py
CHANGED
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs

-    with configure_deep_gemm_num_sms(
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
sglang/srt/utils.py
CHANGED
@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import functools
 import importlib
 import io
 import ipaddress
@@ -837,6 +838,7 @@ class CustomCacheManager(FileCacheManager):


 def set_ulimit(target_soft_limit=65535):
+    # number of open files
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)

@@ -846,6 +848,18 @@ def set_ulimit(target_soft_limit=65535):
         except ValueError as e:
             logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")

+    # stack size
+    resource_type = resource.RLIMIT_STACK
+    current_soft, current_hard = resource.getrlimit(resource_type)
+    target_soft_limit_stack_size = 1024 * target_soft_limit
+    if current_soft < target_soft_limit_stack_size:
+        try:
+            resource.setrlimit(
+                resource_type, (target_soft_limit_stack_size, current_hard)
+            )
+        except ValueError as e:
+            logger.warning(f"Fail to set RLIMIT_STACK: {e}")
+

 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
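The patched `set_ulimit` now raises two separate soft limits: `RLIMIT_NOFILE` (open file descriptors) and `RLIMIT_STACK` (stack size, scaled by 1024). A quick way to inspect both on a POSIX machine (illustrative snippet, not part of sglang):

```python
import resource

# Print the current soft/hard values for the two limits the function touches.
for name in ("RLIMIT_NOFILE", "RLIMIT_STACK"):
    soft, hard = resource.getrlimit(getattr(resource, name))
    print(f"{name}: soft={soft} hard={hard}")
```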
@@ -1373,6 +1387,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)


+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
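`print_info_once` relies on `functools.lru_cache(None)` to deduplicate log lines: the first call with a given message runs and is cached, repeat calls are served from the cache and never reach the logger. A minimal standalone demo of the same pattern (hypothetical helper name):

```python
import functools
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

@functools.lru_cache(None)
def log_once(msg: str) -> None:
    # Only the first call per unique message reaches the logger.
    logger.info(msg)

log_once("weights loaded")
log_once("weights loaded")  # cache hit, nothing is logged
```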
@@ -2197,6 +2216,45 @@ class Withable(Generic[T]):
         self._value = None


+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
+
+
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf

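A quick usage sketch for `merge_bias_tensor` (illustrative tensors; the import path assumes the function lands in `sglang.srt.utils` as shown above). Merging a batch that carries a per-request bias with one that has none fills the missing rows with the default value:

```python
import torch

from sglang.srt.utils import merge_bias_tensor  # assumed import path

lhs = torch.tensor([[0.5], [1.5]])  # bias for a 2-request batch
merged = merge_bias_tensor(lhs, None, bs1=2, bs2=3, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([5, 1]); the last 3 rows are filled with 0.0
```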
sglang/test/attention/test_prefix_chunk_info.py
CHANGED
@@ -2,6 +2,8 @@ import unittest

 import torch

+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.test.test_utils import CustomTestCase
sglang/test/runners.py
CHANGED
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]

 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
                 self.model = _get_sentence_transformer_embedding_model(
                     model_path, torch_dtype
                 )
-            elif self.model_type == "reward":
+            elif self.model_type == "reward" or self.model_type == "cross_encoder":
                 from transformers import AutoModelForSequenceClassification

                 self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                 else:
                     logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
+            elif self.model_type == "cross_encoder":
+                inputs = self.tokenizer(
+                    prompts, padding=True, return_tensors="pt"
+                ).to("cuda")
+                scores = self.model(**inputs).logits
+                scores = scores.squeeze().tolist()
+                if not isinstance(scores, list):
+                    scores = [scores]
+                out_queue.put(ModelOutput(scores=scores))

             elif self.model_type == "reward":
                 scores = []
@@ -322,7 +346,9 @@ class HFRunner:

     def forward(
         self,
-        prompts: Union[
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:

     def forward(
         self,
-        prompts: Union[
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
             else:
                 logits = [response["embedding"]]
             return ModelOutput(embed_logits=logits)
+        # cross encoder model
+        elif self.model_type == "cross_encoder":
+            response = self.engine.rerank(prompts)
+            if not isinstance(response, list):
+                response = [response]
+            scores = [x["embedding"] for x in response]
+            return ModelOutput(scores=scores)
         # reward model
         else:
             response = self.engine.encode(prompts)
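For context, the `cross_encoder` branch added to `HFRunner` follows the usual Hugging Face cross-encoder recipe: each (query, document) pair is tokenized together and the sequence-classification logit is the relevance score. A hedged standalone sketch (the model name is only an example, not necessarily what the sglang tests use):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name).eval()

query = "How many people live in Berlin?"
docs = [
    "Berlin is well known for its museums.",
    "Berlin had a population of 3,520,031 registered inhabitants.",
]

# One row per (query, document) pair; the single logit per row is the relevance score.
inputs = tokenizer([query] * len(docs), docs, padding=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**inputs).logits.squeeze(-1).tolist()
print(scores)
```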