sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     OpenSessionReqInput,
     ParseFunctionCallReq,
     ProfileReqInput,
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
     SeparateReasoningReqInput,
     SetInternalStateReq,
     SlowDownReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -124,8 +126,6 @@ def set_global_state(global_state: _GlobalState):

 @asynccontextmanager
 async def lifespan(fast_api_app: FastAPI):
-    server_args: ServerArgs = fast_api_app.server_args
-
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
         _global_state.tokenizer_manager, _global_state.template_manager
@@ -143,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
         _global_state.tokenizer_manager
     )

+    server_args: ServerArgs = fast_api_app.server_args
     if server_args.warmups is not None:
         await execute_warmups(
-            server_args.warmups.split(","), _global_state.tokenizer_manager
+            server_args.disaggregation_mode,
+            server_args.warmups.split(","),
+            _global_state.tokenizer_manager,
         )
         logger.info("Warmup ended")

@@ -278,13 +281,17 @@ async def get_model_info():
         "model_path": _global_state.tokenizer_manager.model_path,
         "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
         "is_generation": _global_state.tokenizer_manager.is_generation,
+        "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
     }
     return result


 @app.get("/get_server_info")
 async def get_server_info():
-    internal_states = await _global_state.tokenizer_manager.get_internal_state()
+    # Returns internal states per DP.
+    internal_states: List[Dict[Any, Any]] = (
+        await _global_state.tokenizer_manager.get_internal_state()
+    )
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
@@ -298,6 +305,8 @@ async def get_load():
     return await _global_state.tokenizer_manager.get_load()


+# example usage:
+# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
 @app.api_route("/set_internal_state", methods=["POST", "PUT"])
 async def set_internal_state(obj: SetInternalStateReq, request: Request):
     res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -351,8 +360,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "repetition_penalty": 1.2,
-            "temperature": 0.2,
+            "temperature": 0.0,
             "max_new_tokens": 512,
         },
     )
@@ -391,16 +399,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
@@ -595,6 +593,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
         return _create_error_response(e)


+@app.api_route("/load_lora_adapter", methods=["POST"])
+async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
+    """Load a new LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
+@app.api_route("/unload_lora_adapter", methods=["POST"])
+async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
+    """Unload a LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -630,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
 async def abort_request(obj: AbortReq, request: Request):
     """Abort a request."""
     try:
-        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        _global_state.tokenizer_manager.abort_request(
+            rid=obj.rid, abort_all=obj.abort_all
+        )
         return Response(status_code=200)
     except Exception as e:
         return _create_error_response(e)
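The /abort_request route can now cancel either a single request or everything in flight via the new abort_all field on AbortReq. A hypothetical client sketch (server address assumed):

import requests

# Abort a single request by its rid, or all in-flight requests with abort_all.
requests.post("http://localhost:30000/abort_request", json={"rid": "request-id-123"})
requests.post("http://localhost:30000/abort_request", json={"rid": "", "abort_all": True})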
@@ -678,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
     return ORJSONResponse(content=response_data, status_code=200)


+@app.post("/pause_generation")
+async def pause_generation(request: Request):
+    """Pause generation."""
+    await _global_state.tokenizer_manager.pause_generation()
+    return ORJSONResponse(
+        content={"message": "Generation paused successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
+@app.post("/continue_generation")
+async def continue_generation(request: Request):
+    """Continue generation."""
+    await _global_state.tokenizer_manager.continue_generation()
+    return ORJSONResponse(
+        content={"message": "Generation continued successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
 ##### OpenAI-compatible API endpoints #####

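A minimal sketch of driving the new pause/continue endpoints, for example to hold generation while weights are being swapped; the server address is an assumption:

import requests

BASE_URL = "http://localhost:30000"  # assumed default server address

requests.post(f"{BASE_URL}/pause_generation")
# ... perform maintenance, e.g. update model weights ...
requests.post(f"{BASE_URL}/continue_generation")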
@@ -805,6 +859,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )


+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
@@ -851,6 +915,15 @@ def launch_server(
     add_prometheus_middleware(app)
     enable_func_timer()

+    image_token_text = None
+    if (
+        tokenizer_manager.image_token_id is not None
+        and not server_args.skip_tokenizer_init
+    ):
+        image_token_text = tokenizer_manager.tokenizer.decode(
+            [tokenizer_manager.image_token_id]
+        )
+
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -858,7 +931,7 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            _global_state.tokenizer_manager.image_token_id,
+            image_token_text,
             launch_callback,
         ),
     )
@@ -881,11 +954,9 @@ def launch_server(
     warmup_thread.join()


-def _wait_and_warmup(
+def _execute_server_warmup(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[multiprocessing.connection.Connection],
-    image_token_text: str,
-    launch_callback: Optional[Callable[[], None]] = None,
 ):
     headers = {}
     url = server_args.url()
@@ -910,7 +981,7 @@ def _wait_and_warmup(
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
         kill_process_tree(os.getpid())
-        return
+        return success

     model_info = res.json()

@@ -984,12 +1055,28 @@ def _wait_and_warmup(
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
         kill_process_tree(os.getpid())
-        return
+        return False

     # Debug print
-    # logger.info(f"{res.json()=}")
+    # logger.info(f"warmup request returns: {res.json()=}")
+    return success
+
+
+def _wait_and_warmup(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[multiprocessing.connection.Connection],
+    image_token_text: str,
+    launch_callback: Optional[Callable[[], None]] = None,
+):
+    if not server_args.skip_server_warmup:
+        if not _execute_server_warmup(
+            server_args,
+            pipe_finish_writer,
+        ):
+            return

     logger.info("The server is fired up and ready to roll!")
+
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")

@@ -1,4 +1,3 @@
-import base64
 import copy
 import dataclasses
 import multiprocessing
@@ -7,6 +6,7 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union

+import pybase64
 import requests
 import torch
 import torch.distributed as dist
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None

@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[
-        Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
     ] = None
     matched_stop: Union[None, int, str] = None

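Streaming responses can now end with finish_reason == "abort", for example after a request is cancelled through /abort_request. A hypothetical client sketch against the OpenAI-compatible endpoint; the base URL and model name are assumptions:

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Write a long story."}],
    stream=True,
)
for chunk in stream:
    choice = chunk.choices[0]
    if choice.finish_reason == "abort":
        print("\n[generation aborted before completion]")
        break
    if choice.delta.content:
        print(choice.delta.content, end="")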
@@ -3,7 +3,7 @@ from typing import Optional

 import torch

-from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec
+from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec


 class EplbAlgorithm(Enum):
@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, List

 import torch.cuda

-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import ExpertLocationMetadata
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -4,7 +4,7 @@ from pathlib import Path
 import torch
 from tqdm import tqdm

-from sglang.srt.managers.expert_distribution import (
+from sglang.srt.eplb.expert_distribution import (
     _convert_global_physical_count_to_logical_count,
 )

@@ -24,7 +24,7 @@ import einops
 import torch
 import torch.distributed

-from sglang.srt.managers.expert_location import ExpertLocationMetadata
+from sglang.srt.eplb.expert_location import ExpertLocationMetadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         topk_ids = topk_ids.flatten()
         mask = topk_ids != -1
-        assert self._data[layer_idx, :].shape == topk_ids.shape, (
-            "Shape mismatch between data and topk_ids."
-            "Selecting expert is not supported for multiple token prediction at the moment."
-        )
         self._data[layer_idx, :].scatter_add_(
             dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
         )
@@ -23,7 +23,7 @@ import torch.distributed
 import torch.nn.functional as F

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.managers import eplb_algorithms
+from sglang.srt.eplb import eplb_algorithms
 from sglang.srt.model_loader import get_model_architecture
 from sglang.srt.server_args import ServerArgs

@@ -17,7 +17,7 @@ from typing import Literal, Optional

 import torch

-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict


@@ -20,7 +20,7 @@ import torch
 import torch.distributed
 from torch.distributed import P2POp

-from sglang.srt.managers.expert_location import (
+from sglang.srt.eplb.expert_location import (
     ExpertLocationMetadata,
     get_global_expert_location_metadata,
 )
@@ -30,6 +30,9 @@ from sglang.srt.utils import get_bool_env_var
 logger = logging.getLogger(__name__)


+_LOG_INPUT = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_INPUT")
+
+
 class ExpertLocationUpdater:
     def __init__(self):
         self._first_execution = True
@@ -175,6 +178,19 @@ def update_expert_weights_single_layer(
     assert isinstance(old_physical_to_logical_map, list)
     assert isinstance(new_physical_to_logical_map, list)

+    if _LOG_INPUT:
+        logger.info(
+            "update_expert_weights_single_layer "
+            f"{[x.shape for x in routed_experts_weights]=} "
+            f"{[x.shape for x in temp_buffers]=} "
+            f"{old_physical_to_logical_map=} "
+            f"{new_physical_to_logical_map=} "
+            f"{num_local_physical_experts=} "
+            f"{num_gpu_per_node=} "
+            f"{rank=} "
+            f"{world_size=} "
+        )
+
     output_logs = [] if debug else None

     num_physical_experts = len(old_physical_to_logical_map)
@@ -42,7 +42,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url
+from sglang.srt.utils import is_remote_url, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
     return config


+@lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
     trust_remote_code: bool,
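For context, lru_cache_frozenset lets get_config be memoized even though some of its arguments (for example dict-valued override args) are not hashable. A purely illustrative sketch of the idea, not the actual sglang implementation:

import functools


def lru_cache_frozenset(maxsize=32):
    """Illustrative only: memoize a function whose arguments may contain dicts."""

    def _freeze(value):
        # Recursively convert unhashable containers into hashable equivalents.
        if isinstance(value, dict):
            return frozenset((k, _freeze(v)) for k, v in value.items())
        if isinstance(value, (list, tuple, set)):
            return tuple(_freeze(v) for v in value)
        return value

    def decorator(func):
        cache = {}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (_freeze(args), _freeze(sorted(kwargs.items())))
            if key not in cache:
                if len(cache) >= maxsize:
                    cache.pop(next(iter(cache)))  # simple FIFO eviction, not true LRU
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper

    return decorator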
@@ -46,11 +46,11 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

-logger = logging.getLogger(__name__)
-
 if is_npu():
     import torch_npu

+logger = logging.getLogger(__name__)
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -0,0 +1,86 @@
+import logging
+
+import torch
+
+from sglang.srt.utils import cpu_has_amx_support
+
+logger = logging.getLogger(__name__)
+
+
+def amx_process_weight_after_loading(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    TILE_N = 16
+    TILE_K = 32
+    ndim = weight.ndim
+    OC = weight.size(1) if ndim == 3 else weight.size(0)
+    IC = weight.size(2) if ndim == 3 else weight.size(1)
+    return OC % TILE_N == 0 and IC % TILE_K == 0
+
+
+def _amx_process_weight_after_loading(
+    module, weight_names, transpose_dims=None
+) -> None:
+    # Pack weight for get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+                f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        packed_weight = torch.nn.Parameter(
+            amx_process_weight_after_loading(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _amx_process_weight_after_loading(
+            module, self.weight_names, self.transpose_dims
+        )
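A hypothetical usage sketch for the new amx_utils module: a CPU module registers PackWeightMethod so its weight is repacked for the Intel AMX kernels after loading. SimpleLinear below is an illustration, not sglang code.

import torch

from sglang.srt.layers.amx_utils import PackWeightMethod


class SimpleLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(out_features, in_features), requires_grad=False
        )
        # Repack "weight" in place once weights are on CPU; no transpose needed.
        self.quant_method = PackWeightMethod(weight_names=["weight"])


module = SimpleLinear(in_features=4096, out_features=4096)
# Typically called by the model loader after weights are materialized.
module.quant_method.process_weights_after_loading(module)
print(getattr(module, "use_intel_amx_backend", None))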