sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (142)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/srt/configs/model_config.py +11 -2
  3. sglang/srt/layers/attention/__init__.py +0 -1
  4. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  5. sglang/srt/layers/logits_processor.py +30 -2
  6. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  71. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
  77. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  78. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  84. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  85. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  89. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  91. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  93. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  95. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  96. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  97. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  98. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  100. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  103. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  104. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  105. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  106. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  107. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  108. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  109. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  112. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  113. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  114. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  115. sglang/srt/layers/quantization/fp8.py +42 -2
  116. sglang/srt/layers/quantization/fp8_kernel.py +77 -18
  117. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  118. sglang/srt/managers/detokenizer_manager.py +2 -0
  119. sglang/srt/managers/io_struct.py +40 -9
  120. sglang/srt/managers/schedule_batch.py +22 -15
  121. sglang/srt/managers/scheduler.py +69 -21
  122. sglang/srt/managers/session_controller.py +102 -27
  123. sglang/srt/managers/tokenizer_manager.py +48 -10
  124. sglang/srt/managers/tp_worker.py +7 -0
  125. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  126. sglang/srt/model_executor/forward_batch_info.py +42 -3
  127. sglang/srt/model_executor/model_runner.py +4 -0
  128. sglang/srt/models/llama.py +11 -0
  129. sglang/srt/models/llama_eagle.py +132 -0
  130. sglang/srt/openai_api/adapter.py +60 -2
  131. sglang/srt/openai_api/protocol.py +48 -0
  132. sglang/srt/server.py +26 -3
  133. sglang/srt/server_args.py +24 -30
  134. sglang/srt/speculative/spec_info.py +19 -0
  135. sglang/srt/utils.py +62 -0
  136. sglang/version.py +1 -1
  137. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/METADATA +3 -3
  138. sglang-0.4.1.post3.dist-info/RECORD +305 -0
  139. sglang-0.4.1.post1.dist-info/RECORD +0 -195
  140. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/LICENSE +0 -0
  141. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -236,6 +237,7 @@ class Req:
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@ class Req:
 
         last_token_id = self.output_ids[-1]
 
-        matched_eos = False
-
-        # Check stop token ids
-        if self.sampling_params.stop_token_ids:
-            matched_eos = last_token_id in self.sampling_params.stop_token_ids
-        if self.tokenizer is not None:
-            matched_eos |= last_token_id == self.tokenizer.eos_token_id
-            if self.tokenizer.additional_stop_token_ids:
-                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
-        if matched_eos and not self.sampling_params.ignore_eos:
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-            return
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
+
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return
 
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
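Note: the stop-check hunk above makes two behavioral changes. First, `ignore_eos` now short-circuits the whole EOS check up front (previously `matched_eos` was computed unconditionally and only discarded at the end). Second, a per-request `eos_token_ids` set, filled from the model config's `hf_eos_token_id` (see the `handle_generate_request` hunk further down), is consulted in addition to the sampling params and the tokenizer. A minimal standalone sketch of the resulting precedence, not the literal Req method:

    # Sketch mirroring the new check; argument names mirror the Req fields.
    def matches_eos(last_token_id, sampling_params, eos_token_ids, tokenizer):
        if sampling_params.ignore_eos:
            return False  # new: skips every EOS source, including the tokenizer's
        matched = False
        if sampling_params.stop_token_ids:      # per-request stop ids
            matched = last_token_id in sampling_params.stop_token_ids
        if eos_token_ids:                       # new: EOS ids from the HF model config
            matched |= last_token_id in eos_token_ids
        if tokenizer is not None:               # tokenizer-declared EOS ids
            matched |= last_token_id == tokenizer.eos_token_id
            if tokenizer.additional_stop_token_ids:
                matched |= last_token_id in tokenizer.additional_stop_token_ids
        return matched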
@@ -836,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
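Note: `check_decode_mem` gains a `buf_multiplier` parameter, so a caller can require more than one free KV-cache slot per request before running a decode step; the default of 1 preserves the old behavior. A plausible use (an assumption, not shown in this diff) is speculative decoding, where each step may commit several tokens per request; this release also adds `sglang/srt/speculative/spec_info.py`. A hedged sketch:

    # Hypothetical caller; `batch` is a ScheduleBatch.
    tokens_per_step = 4  # e.g. draft length; made-up value for illustration
    if not batch.check_decode_mem(buf_multiplier=tokens_per_step):
        ...  # not enough free KV-cache slots: retract or delay the batch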
--- a/sglang/srt/managers/scheduler.py
+++ b/sglang/srt/managers/scheduler.py
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import psutil
 import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -88,7 +90,7 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-# Test retract decode
+# Test retract decode for debugging purposes
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
@@ -127,12 +129,12 @@ class Scheduler:
             )
 
         if server_args.skip_tokenizer_init:
-            # Directly send to the tokenizer/api
+            # Directly send to the TokenizerManager
            self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name
             )
         else:
-            # Send to the detokenizer
+            # Send to the DetokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name
             )
@@ -383,7 +385,8 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
-            if self.server_args.enable_dp_attention:
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
                 batch = self.prepare_dp_attn_batch(batch)
 
             self.cur_batch = batch
@@ -392,7 +395,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -409,12 +412,13 @@ class Scheduler:
 
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
+
             if batch:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
 
                 if self.last_batch is None:
-                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -424,19 +428,21 @@ class Scheduler:
                     self.process_batch_result(tmp_batch, None)
 
             if self.last_batch:
+                # Process the results of the last batch
                 tmp_batch, tmp_result = result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # Self-check and re-init some states when the server is idle
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
-    def recv_requests(self):
+    def recv_requests(self) -> List[Req]:
+        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
@@ -478,6 +484,11 @@
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightsFromDistributedReqOutput(success, message)
                 )
+            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
+                success, message = self.update_weights_from_tensor(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightsFromTensorReqOutput(success, message)
+                )
             elif isinstance(recv_req, GetWeightsByNameReqInput):
                 parameter = self.get_weights_by_name(recv_req)
                 self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -487,8 +498,10 @@
                 else:
                     self.stop_profile()
             elif isinstance(recv_req, OpenSessionReqInput):
-                session_id = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+                session_id, success = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    OpenSessionReqOutput(session_id=session_id, success=success)
+                )
             elif isinstance(recv_req, CloseSessionReqInput):
                 self.close_session(recv_req)
             else:
@@ -499,7 +512,11 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         # Create a new request
-        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+        if (
+            recv_req.session_params is None
+            or recv_req.session_params.id is None
+            or recv_req.session_params.id not in self.sessions
+        ):
 
             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
@@ -517,18 +534,22 @@
                 stream=recv_req.stream,
                 lora_path=recv_req.lora_path,
                 input_embeds=recv_req.input_embeds,
+                eos_token_ids=self.model_config.hf_eos_token_id,
             )
             req.tokenizer = self.tokenizer
 
-            if recv_req.session_id is not None:
+            if (
+                recv_req.session_params is not None
+                and recv_req.session_params.id is not None
+            ):
                 req.finished_reason = FINISH_ABORT(
-                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
                 self.waiting_queue.append(req)
                 return
         else:
-            # Create a new request from a previsou session
-            session = self.sessions[recv_req.session_id]
+            # Create a new request from a previous session
+            session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
@@ -804,6 +825,8 @@
             if res == AddReqResult.NO_TOKEN:
                 self.batch_is_full = True
                 break
+            if self.server_args.prefill_only_one_req:
+                break
 
         # Update waiting queue
         can_run_list = adder.can_run_list
@@ -1195,6 +1218,7 @@
             decode_ids_list = []
             read_offsets = []
             output_ids = []
+            origin_input_ids = []
 
             skip_special_tokens = []
             spaces_between_special_tokens = []
@@ -1243,8 +1267,14 @@
                 decode_ids, read_offset = req.init_incremental_detokenize()
                 decode_ids_list.append(decode_ids)
                 read_offsets.append(read_offset)
-                if self.skip_tokenizer_init:
+                if self.skip_tokenizer_init or self.server_args.return_token_ids:
                     output_ids.append(req.output_ids)
+                else:
+                    output_ids = None
+                if self.server_args.return_token_ids:
+                    origin_input_ids.append(req.origin_input_ids)
+                else:
+                    origin_input_ids = None
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
@@ -1276,6 +1306,7 @@
                     decoded_texts,
                     decode_ids_list,
                     read_offsets,
+                    origin_input_ids,
                     output_ids,
                     skip_special_tokens,
                     spaces_between_special_tokens,
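Note: the two hunks above, together with the TokenizerManager hunk near the end, implement a new `return_token_ids` server option (`server_args.return_token_ids`; the exact CLI spelling is an assumption): when it is on, `origin_input_ids` and `output_ids` travel with the batch output and are merged into the HTTP response. One quirk visible above: the lists are reset to `None` whenever the option is off, so they are all-or-nothing per batch. A sketch of the resulting /generate response shape, with made-up token ids:

    # Hypothetical response dict when server_args.return_token_ids is enabled.
    {
        "text": " Paris.",
        "meta_info": {...},  # unchanged fields
        "input_ids": [3923, 374, 279, 6864, 315, 9822, 30],  # origin_input_ids
        "output_ids": [12366, 13],
    }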
@@ -1457,6 +1488,17 @@
             logger.error(message)
         return success, message
 
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        """Update the online model parameter from tensors."""
+        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
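Note: the new `Scheduler.update_weights_from_tensor` mirrors the `update_weights_from_distributed` contract: delegate to the TP worker, then flush the KV cache on success so no cached prefix computed with the old weights survives the swap. Driving it goes through the async TokenizerManager entry point added further down; a hedged sketch (the `name`/`tensor` field names come from the TpModelWorker hunk near the end, the parameter name and shape are made up):

    import torch
    from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput

    async def swap_lm_head(tokenizer_manager):
        # Made-up parameter name and tensor, for illustration only.
        req = UpdateWeightsFromTensorReqInput(
            name="lm_head.weight",
            tensor=torch.zeros(128256, 4096, dtype=torch.bfloat16),
        )
        success, message = await tokenizer_manager.update_weights_from_tensor(req)
        assert success, message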
@@ -1475,16 +1517,20 @@
         )
         logger.info("Profiler is done")
 
-    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
+            return session_id, False
+        elif session_id is None:
+            logger.warning(f"session id is None, cannot open.")
+            return session_id, False
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-            return session_id
+            return session_id, True
 
     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
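Note: `open_session` now returns a `(session_id, success)` tuple instead of a bare id, covering two error paths: a duplicate session id and a `None` id. The TokenizerManager hunk further down resolves the awaiting future to `None` on failure, so client code sees:

    # Hedged client-side sketch of the new contract.
    async def open_or_fail(tokenizer_manager, req):  # req: OpenSessionReqInput
        session_id = await tokenizer_manager.open_session(req)
        if session_id is None:
            raise RuntimeError("session not opened (duplicate or invalid id)")
        return session_id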
@@ -1509,18 +1555,20 @@ def run_scheduler_process(
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
 
+    # Configue the logger
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+    suppress_other_loggers()
 
-    # set cpu affinity to this gpu process
+    # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    suppress_other_loggers()
 
     parent_process = psutil.Process().parent()
+    # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(
--- a/sglang/srt/managers/session_controller.py
+++ b/sglang/srt/managers/session_controller.py
@@ -10,41 +10,116 @@
 # limitations under the License.
 # ==============================================================================
 
+import logging
 import uuid
+from typing import Dict, Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+from sglang.srt.managers.schedule_batch import Req
+
+
+class SessionReqNode:
+    def __init__(self, req, parent=None, childs=None):
+        self.req = req
+        self.parent = parent
+        if parent is not None:
+            parent.childs.append(self)
+        self.childs = [] if not childs else childs
+
+    def clear_childs(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        self.childs = []
+
+    def clear(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+        del req_dict[self.req.rid]
+
+    def abort(self):
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+
+    def __str__(self):
+        return self._str_helper(self.req.rid)
+
+    def _str_helper(self, prefix=""):
+        if len(self.childs) == 0:
+            return prefix + "\n"
+        else:
+            origin_prefix = prefix
+            prefix += " -- " + self.childs[0].req.rid
+            ret = self.childs[0]._str_helper(prefix)
+            for child in self.childs[1:]:
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                ret += child._str_helper(prefix)
+            return ret
 
 
 class Session:
-    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
         self.session_id = session_id if session_id is not None else uuid.uuid4().hex
         self.capacity_of_str_len = capacity_of_str_len
-        self.reqs: List[Req] = []
+        self.req_nodes: Dict[str, SessionReqNode] = {}
 
     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        if req.session_rid is not None:
-            while len(self.reqs) > 0:
-                if self.reqs[-1].rid == req.session_rid:
-                    break
-                self.reqs = self.reqs[:-1]
+        assert req.session_params is not None
+        session_params = req.session_params
+
+        last_req_node = None
+        last_req = None
+        abort = False
+        if session_params.replace:
+            if session_params.rid is None:
+                for _, req_node in self.req_nodes.items():
+                    req_node.clear(self.req_nodes)
+            else:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req_node.abort()
+                    last_req = last_req_node.req
+                    last_req_node.clear_childs(self.req_nodes)
         else:
-            self.reqs = []
-        if len(self.reqs) > 0:
+            if session_params.rid is not None:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req = last_req_node.req
+                    if not last_req.finished():
+                        logging.warning(
+                            "The request in a session is appending to a request that hasn't finished."
+                        )
+                        abort = True
+
+        if last_req is not None:
+            # trim bos token if it is an append
+            if req.input_ids[0] == tokenizer.bos_token_id:
+                req.input_ids = req.input_ids[1:]
+
             input_ids = (
-                self.reqs[-1].origin_input_ids
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids = input_ids[: session_params.offset] + req.input_ids
+            else:
+                input_ids += req.input_ids
             input_ids_unpadded = (
-                self.reqs[-1].origin_input_ids_unpadded
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
+                last_req.origin_input_ids_unpadded
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
             )
+            if session_params.offset and session_params.offset != 0:
+                input_ids_unpadded = (
+                    input_ids_unpadded[: session_params.offset] + req.input_ids
+                )
+            else:
+                input_ids_unpadded += req.input_ids
         else:
             input_ids = req.input_ids
             input_ids_unpadded = req.input_ids
@@ -57,13 +132,13 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
-        if len(self.reqs) > 0:
-            new_req.image_inputs = self.reqs[-1].image_inputs
+        if last_req is not None:
+            new_req.image_inputs = last_req.image_inputs
         new_req.tokenizer = tokenizer
-        if req.session_rid is not None and len(self.reqs) == 0:
-            new_req.finished_reason = FINISH_ABORT(
-                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
-            )
+        if abort:
+            new_req.to_abort = True
         else:
-            self.reqs.append(new_req)
+            new_req_node = SessionReqNode(new_req, last_req_node)
+            self.req_nodes[req.rid] = new_req_node
+
         return new_req
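Note: a session is no longer a linear list of requests but a tree of `SessionReqNode`s: multiple continuations can branch from the same parent request, and `replace=True` prunes a node's subtree, aborting any unfinished requests in it. A small self-contained demo of the bookkeeping, using a stand-in request object:

    from sglang.srt.managers.session_controller import SessionReqNode

    # Stand-in Req carrying only the fields SessionReqNode touches.
    class FakeReq:
        def __init__(self, rid):
            self.rid, self.finished_reason, self.to_abort = rid, None, False

    req_dict = {}
    root = SessionReqNode(FakeReq("a"))
    req_dict["a"] = root
    for rid in ("b", "c"):  # two branches off the same parent
        req_dict[rid] = SessionReqNode(FakeReq(rid), parent=root)

    print(root, end="")          # roughly: a -- b
                                 #            \- c
    root.clear_childs(req_dict)  # prune both branches, aborting them
    assert list(req_dict) == ["a"] and root.childs == []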
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -53,12 +53,15 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -179,6 +182,9 @@ class TokenizerManager:
         self.update_weights_from_distributed_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_weights_from_tensor_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -259,8 +265,9 @@ class TokenizerManager:
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
-            session_id = obj.session[0] if obj.session else None
-            session_rid = obj.session[1] if obj.session else None
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
 
         if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
@@ -287,8 +294,7 @@ class TokenizerManager:
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
-                session_id=session_id,
-                session_rid=session_rid,
+                session_params=session_params,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
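Note: on the API surface, the old `session=[session_id, rid]` pair becomes a `session_params` dict, rehydrated here via `SessionParams(**obj.session_params)`. The field names used elsewhere in this diff are `id`, `rid`, `offset`, and `replace` (see the `Session.create_req` hunk above); the dataclass itself lives in the io_struct.py changes (+40 -9) not shown here. A hedged example payload for continuing an earlier request within a session:

    # Field names inferred from their uses in this diff; treat as a sketch.
    payload = {
        "text": "And what about Italy?",
        "session_params": {
            "id": "sess-123",   # session to extend
            "rid": "req-0",     # request to append to
            "offset": 0,        # optional trim point into the history
            "replace": False,   # True would prune req-0's subtree first
        },
        "sampling_params": {"max_new_tokens": 32},
    }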
@@ -515,6 +521,22 @@ class TokenizerManager:
             result = (await self.update_weights_from_distributed_communicator(obj))[0]
             return result.success, result.message
 
+    async def update_weights_from_tensor(
+        self,
+        obj: UpdateWeightsFromTensorReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> Tuple[bool, str]:
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"
+
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_weights_from_tensor_communicator(obj))[0]
+            return result.success, result.message
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -531,12 +553,16 @@ class TokenizerManager:
     ):
         self.auto_create_handle_loop()
 
-        session_id = uuid.uuid4().hex
-        obj.session_id = session_id
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None
+
         self.send_to_scheduler.send_pyobj(obj)
-        self.session_futures[session_id] = asyncio.Future()
-        session_id = await self.session_futures[session_id]
-        del self.session_futures[session_id]
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
         return session_id
 
     async def close_session(
@@ -637,6 +663,13 @@ class TokenizerManager:
                         "text": recv_obj.output_strs[i],
                         "meta_info": meta_info,
                     }
+                    if self.server_args.return_token_ids:
+                        out_dict.update(
+                            {
+                                "input_ids": recv_obj.origin_input_ids[i],
+                                "output_ids": recv_obj.output_ids[i],
+                            }
+                        )
                 elif isinstance(recv_obj, BatchTokenIDOut):
                     out_dict = {
                         "token_ids": recv_obj.output_ids[i],
@@ -688,7 +721,7 @@ class TokenizerManager:
                 )
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
+                    recv_obj.session_id if recv_obj.success else None
                 )
             elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
@@ -708,6 +741,11 @@ class TokenizerManager:
                     self.server_args.dp_size == 1
                 ), "dp_size must be 1 for update weights from distributed"
                 self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
             else:
--- a/sglang/srt/managers/tp_worker.py
+++ b/sglang/srt/managers/tp_worker.py
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -188,6 +189,12 @@ class TpModelWorker:
         )
         return success, message
 
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.model_runner.update_weights_from_tensor(
+            recv_req.name, recv_req.tensor
+        )
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.model_runner.get_weights_by_name(
             recv_req.name, recv_req.truncate_size
--- a/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
@@ -225,6 +226,10 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_distributed(recv_req)
         return success, message
 
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        success, message = self.worker.update_weights_from_tensor(recv_req)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)
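Note: read together, the update-weights-from-tensor hunks thread one RPC through every layer of the stack; the overlap-thread worker simply forwards to the wrapped `TpModelWorker`. A sketch of the chain, reconstructed from the hunks above (pseudocode comments, not literal source):

    # TokenizerManager.update_weights_from_tensor(obj)   # async; holds model_update_lock.writer_lock
    #   -> sends UpdateWeightsFromTensorReqInput over ZMQ
    # Scheduler.update_weights_from_tensor(recv_req)
    #   -> TpModelWorker[Client].update_weights_from_tensor(recv_req)
    #        -> ModelRunner.update_weights_from_tensor(recv_req.name, recv_req.tensor)
    #   -> flush_cache() on success
    #   -> replies UpdateWeightsFromTensorReqOutput(success, message) back to the caller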