sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     OpenSessionReqInput,
     ParseFunctionCallReq,
     ProfileReqInput,
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
     SeparateReasoningReqInput,
     SetInternalStateReq,
     SlowDownReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -124,8 +126,6 @@ def set_global_state(global_state: _GlobalState):

 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args: ServerArgs = fast_api_app.server_args
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -143,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )

+    server_args: ServerArgs = fast_api_app.server_args
     if server_args.warmups is not None:
         await execute_warmups(
-            server_args.
+            server_args.disaggregation_mode,
+            server_args.warmups.split(","),
+            _global_state.tokenizer_manager,
         )
         logger.info("Warmup ended")

@@ -278,13 +281,17 @@ async def get_model_info():
         "model_path": _global_state.tokenizer_manager.model_path,
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
+        "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
     }
     return result


 @app.get("/get_server_info")
 async def get_server_info():
-
+    # Returns interna states per DP.
+    internal_states: List[Dict[Any, Any]] = (
+        await _global_state.tokenizer_manager.get_internal_state()
+    )
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
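For reference, a minimal client-side sketch of reading the newly exposed preferred_sampling_params field from /get_model_info; the server address and the print statement are illustrative assumptions, not part of the diff.

import requests

BASE_URL = "http://localhost:30000"  # illustrative address, not from the diff

# /get_model_info now also carries preferred_sampling_params (may be null).
info = requests.get(f"{BASE_URL}/get_model_info").json()
print(info["model_path"], info.get("preferred_sampling_params"))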
@@ -298,6 +305,8 @@ async def get_load():
     return await _global_state.tokenizer_manager.get_load()


+# example usage:
+# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
 @app.api_route("/set_internal_state", methods=["POST", "PUT"])
 async def set_internal_state(obj: SetInternalStateReq, request: Request):
     res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -351,8 +360,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "
-            "temperature": 0.2,
+            "temperature": 0.0,
             "max_new_tokens": 512,
         },
     )
@@ -391,16 +399,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
@@ -595,6 +593,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
         return _create_error_response(e)


+@app.api_route("/load_lora_adapter", methods=["POST"])
+async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
+    """Load a new LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
+@app.api_route("/unload_lora_adapter", methods=["POST"])
+async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
+    """Load a new LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
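A hedged client-side sketch of the new LoRA endpoints. The payload field names (lora_name, lora_path) are assumptions about LoadLoRAAdapterReqInput / UnloadLoRAAdapterReqInput, which this diff does not show; only the routes and handlers above are taken from the diff.

import requests

BASE_URL = "http://localhost:30000"  # illustrative address

# Assumed request fields; check io_struct.LoadLoRAAdapterReqInput for the real schema.
load_resp = requests.post(
    f"{BASE_URL}/load_lora_adapter",
    json={"lora_name": "my-adapter", "lora_path": "/path/to/adapter"},
)
print(load_resp.status_code, load_resp.json())

unload_resp = requests.post(
    f"{BASE_URL}/unload_lora_adapter",
    json={"lora_name": "my-adapter"},
)
print(unload_resp.status_code, unload_resp.json())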
@@ -630,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
 async def abort_request(obj: AbortReq, request: Request):
     """Abort a request."""
     try:
-        _global_state.tokenizer_manager.abort_request(
+        _global_state.tokenizer_manager.abort_request(
+            rid=obj.rid, abort_all=obj.abort_all
+        )
         return Response(status_code=200)
     except Exception as e:
         return _create_error_response(e)
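The handler above now forwards rid and abort_all to the tokenizer manager. A small sketch of calling it follows; the /abort_request path is inferred from the handler name (the route decorator is outside this hunk), and the address and request id are illustrative.

import requests

BASE_URL = "http://localhost:30000"  # illustrative address

# Abort one request by id, or everything in flight with abort_all=True.
requests.post(f"{BASE_URL}/abort_request", json={"rid": "example-request-id", "abort_all": False})
requests.post(f"{BASE_URL}/abort_request", json={"abort_all": True})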
@@ -678,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
     return ORJSONResponse(content=response_data, status_code=200)


+@app.post("/pause_generation")
+async def pause_generation(request: Request):
+    """Pause generation."""
+    await _global_state.tokenizer_manager.pause_generation()
+    return ORJSONResponse(
+        content={"message": "Generation paused successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
+@app.post("/continue_generation")
+async def continue_generation(request: Request):
+    """Continue generation."""
+    await _global_state.tokenizer_manager.continue_generation()
+    return ORJSONResponse(
+        content={"message": "Generation continued successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
 ##### OpenAI-compatible API endpoints #####

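A minimal sketch of driving the new pause/resume endpoints; they take no request body, and the server address below is an illustrative assumption.

import requests

BASE_URL = "http://localhost:30000"  # illustrative address

# Temporarily stop scheduling new generation work, then resume it.
print(requests.post(f"{BASE_URL}/pause_generation").json())
print(requests.post(f"{BASE_URL}/continue_generation").json())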
@@ -805,6 +859,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )


+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
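The /v1/rerank route is unchanged in behavior, only relocated after the score endpoint. A hedged client sketch follows; the query/documents payload shape is an assumption about V1RerankReqInput, which is not shown in this diff.

import requests

BASE_URL = "http://localhost:30000"  # illustrative address

# Assumed payload shape; consult V1RerankReqInput for the actual fields.
resp = requests.post(
    f"{BASE_URL}/v1/rerank",
    json={"query": "what is sglang?", "documents": ["a serving framework", "a fruit"]},
)
print(resp.json())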
@@ -851,6 +915,15 @@ def launch_server(
     add_prometheus_middleware(app)
     enable_func_timer()

+    image_token_text = None
+    if (
+        tokenizer_manager.image_token_id is not None
+        and not server_args.skip_tokenizer_init
+    ):
+        image_token_text = tokenizer_manager.tokenizer.decode(
+            [tokenizer_manager.image_token_id]
+        )
+
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -858,7 +931,7 @@
         args=(
             server_args,
             pipe_finish_writer,
-
+            image_token_text,
             launch_callback,
         ),
     )
@@ -881,11 +954,9 @@
     warmup_thread.join()


-def
+def _execute_server_warmup(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[multiprocessing.connection.Connection],
-    image_token_text: str,
-    launch_callback: Optional[Callable[[], None]] = None,
 ):
     headers = {}
     url = server_args.url()
@@ -910,7 +981,7 @@ def _wait_and_warmup(
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
-        return
+        return success

    model_info = res.json()

@@ -984,12 +1055,28 @@ def _wait_and_warmup(
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
-        return
+        return False

    # Debug print
-    # logger.info(f"{res.json()=}")
+    # logger.info(f"warmup request returns: {res.json()=}")
+    return success
+
+
+def _wait_and_warmup(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[multiprocessing.connection.Connection],
+    image_token_text: str,
+    launch_callback: Optional[Callable[[], None]] = None,
+):
+    if not server_args.skip_server_warmup:
+        if not _execute_server_warmup(
+            server_args,
+            pipe_finish_writer,
+        ):
+            return

    logger.info("The server is fired up and ready to roll!")
+
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

@@ -1,4 +1,3 @@
-import base64
 import copy
 import dataclasses
 import multiprocessing
@@ -7,6 +6,7 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union

+import pybase64
 import requests
 import torch
 import torch.distributed as dist
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None

@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[
-        Literal[
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
     ] = None
     matched_stop: Union[None, int, str] = None

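Both streaming choice models now admit finish_reason == "abort". A small client-side sketch, assuming chunks are already parsed into dicts from an OpenAI-compatible stream (the helper name and structure are illustrative):

def classify_finish(chunk: dict) -> str:
    # chunk: one parsed streaming payload from /v1/completions or /v1/chat/completions
    finish_reason = chunk["choices"][0].get("finish_reason")
    if finish_reason == "abort":
        return "request was aborted (for example via /abort_request)"
    if finish_reason is None:
        return "still streaming"
    return f"finished: {finish_reason}"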
@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, List

 import torch.cuda

-from sglang.srt.
-
-)
-from sglang.srt.managers.expert_location import ExpertLocationMetadata
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -24,7 +24,7 @@ import einops
 import torch
 import torch.distributed

-from sglang.srt.
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         topk_ids = topk_ids.flatten()
         mask = topk_ids != -1
-        assert self._data[layer_idx, :].shape == topk_ids.shape, (
-            "Shape mismatch between data and topk_ids."
-            "Selecting expert is not supported for multiple token prediction at the moment."
-        )
         self._data[layer_idx, :].scatter_add_(
             dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
         )
@@ -23,7 +23,7 @@ import torch.distributed
 import torch.nn.functional as F

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.
+from sglang.srt.eplb import eplb_algorithms
 from sglang.srt.model_loader import get_model_architecture
 from sglang.srt.server_args import ServerArgs

@@ -17,7 +17,7 @@ from typing import Literal, Optional

 import torch

-from sglang.srt.
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict


@@ -20,7 +20,7 @@ import torch
 import torch.distributed
 from torch.distributed import P2POp

-from sglang.srt.
+from sglang.srt.eplb.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
@@ -30,6 +30,9 @@ from sglang.srt.utils import get_bool_env_var
 logger = logging.getLogger(__name__)


+_LOG_INPUT = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT")
+
+
 class ExpertLocationUpdater:
     def __init__(self):
         self._first_execution = True
@@ -175,6 +178,19 @@ def update_expert_weights_single_layer(
     assert isinstance(old_physical_to_logical_map, list)
     assert isinstance(new_physical_to_logical_map, list)

+    if _LOG_INPUT:
+        logger.info(
+            "update_expert_weights_single_layer "
+            f"{[x.shape for x in routed_experts_weights]=} "
+            f"{[x.shape for x in temp_buffers]=} "
+            f"{old_physical_to_logical_map=} "
+            f"{new_physical_to_logical_map=} "
+            f"{num_local_physical_experts=} "
+            f"{num_gpu_per_node=} "
+            f"{rank=} "
+            f"{world_size=} "
+        )
+
     output_logs = [] if debug else None

     num_physical_experts = len(old_physical_to_logical_map)
@@ -42,7 +42,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url
+from sglang.srt.utils import is_remote_url, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
     return config


+@lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
     trust_remote_code: bool,
sglang/srt/layers/activation.py CHANGED
@@ -46,11 +46,11 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

-logger = logging.getLogger(__name__)
-
 if is_npu():
     import torch_npu

+logger = logging.getLogger(__name__)
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -0,0 +1,86 @@
+import logging
+
+import torch
+
+from sglang.srt.utils import cpu_has_amx_support
+
+logger = logging.getLogger(__name__)
+
+
+def amx_process_weight_after_loading(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    TILE_N = 16
+    TILE_K = 32
+    ndim = weight.ndim
+    OC = weight.size(1) if ndim == 3 else weight.size(0)
+    IC = weight.size(2) if ndim == 3 else weight.size(1)
+    return OC % TILE_N == 0 and IC % TILE_K == 0
+
+
+def _amx_process_weight_after_loading(
+    module, weight_names, transpose_dims=None
+) -> None:
+    # Pack weight for get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+                f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        packed_weight = torch.nn.Parameter(
+            amx_process_weight_after_loading(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _amx_process_weight_after_loading(
+            module, self.weight_names, self.transpose_dims
+        )
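To show how the new amx_utils helper is intended to be wired in, here is a rough usage sketch; the Linear module and its shapes are made up for the example, and actual prepacking only happens on a CPU with AMX support and the sgl_kernel op available. Elsewhere the weight is left unpacked and use_intel_amx_backend ends up False.

import torch

from sglang.srt.layers.amx_utils import PackWeightMethod

# Hypothetical module: out_features divisible by 16 and in_features divisible by 32,
# so dim_is_supported() accepts the weight shape (32, 64).
linear = torch.nn.Linear(in_features=64, out_features=32, bias=True)

# Attach the packing method and run it once after weights are loaded.
# On a CPU with AMX (and sgl_kernel installed) this prepacks the weight and
# sets linear.use_intel_amx_backend = True; otherwise the flag stays False.
pack_method = PackWeightMethod(weight_names=["weight"])
pack_method.process_weights_after_loading(linear)
print(linear.use_intel_amx_backend)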