sglang 0.3.4.post1__tar.gz → 0.3.4.post2__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (154)
  1. {sglang-0.3.4.post1/sglang.egg-info → sglang-0.3.4.post2}/PKG-INFO +13 -14
  2. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/README.md +12 -13
  3. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/pyproject.toml +1 -1
  4. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/configs/model_config.py +25 -2
  5. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/constrained/fsm_cache.py +10 -3
  6. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/hf_transformers_utils.py +14 -0
  7. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  8. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/logits_processor.py +5 -5
  9. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/rotary_embedding.py +15 -48
  10. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/sampler.py +51 -39
  11. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/data_parallel_controller.py +1 -1
  12. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/detokenizer_manager.py +4 -0
  13. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/io_struct.py +10 -0
  14. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/schedule_batch.py +13 -3
  15. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/scheduler.py +8 -2
  16. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/tokenizer_manager.py +14 -0
  17. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  18. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/mem_cache/memory_pool.py +10 -3
  19. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/model_executor/cuda_graph_runner.py +29 -21
  20. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/model_executor/forward_batch_info.py +6 -9
  21. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/model_executor/model_runner.py +2 -2
  22. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  23. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/sampling_params.py +5 -7
  24. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/server.py +12 -0
  25. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/run_eval.py +2 -0
  26. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  27. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/test_utils.py +100 -3
  28. sglang-0.3.4.post2/sglang/version.py +1 -0
  29. {sglang-0.3.4.post1 → sglang-0.3.4.post2/sglang.egg-info}/PKG-INFO +13 -14
  30. sglang-0.3.4.post1/sglang/version.py +0 -1
  31. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/LICENSE +0 -0
  32. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/setup.cfg +0 -0
  33. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/__init__.py +0 -0
  34. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/api.py +0 -0
  35. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/bench_latency.py +0 -0
  36. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/bench_server_latency.py +0 -0
  37. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/bench_serving.py +0 -0
  38. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/check_env.py +0 -0
  39. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/global_config.py +0 -0
  40. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/__init__.py +0 -0
  41. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/__init__.py +0 -0
  42. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/anthropic.py +0 -0
  43. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/base_backend.py +0 -0
  44. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/litellm.py +0 -0
  45. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/openai.py +0 -0
  46. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/runtime_endpoint.py +0 -0
  47. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/backend/vertexai.py +0 -0
  48. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/chat_template.py +0 -0
  49. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/choices.py +0 -0
  50. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/compiler.py +0 -0
  51. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/interpreter.py +0 -0
  52. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/ir.py +0 -0
  53. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/lang/tracer.py +0 -0
  54. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/launch_server.py +0 -0
  55. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/launch_server_llavavid.py +0 -0
  56. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/configs/__init__.py +0 -0
  57. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/configs/exaone.py +0 -0
  58. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/configs/qwen2vl.py +0 -0
  59. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/constrained/__init__.py +0 -0
  60. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/constrained/base_tool_cache.py +0 -0
  61. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/constrained/jump_forward.py +0 -0
  62. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/conversation.py +0 -0
  63. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/activation.py +0 -0
  64. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/__init__.py +0 -0
  65. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/double_sparsity_backend.py +0 -0
  66. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/triton_backend.py +0 -0
  67. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/triton_ops/decode_attention.py +0 -0
  68. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +0 -0
  69. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/triton_ops/extend_attention.py +0 -0
  70. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/attention/triton_ops/prefill_attention.py +0 -0
  71. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/fused_moe/__init__.py +0 -0
  72. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/fused_moe/fused_moe.py +0 -0
  73. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/fused_moe/layer.py +0 -0
  74. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/fused_moe/patch.py +0 -0
  75. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/layernorm.py +0 -0
  76. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/linear.py +0 -0
  77. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/pooler.py +0 -0
  78. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/quantization/__init__.py +0 -0
  79. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/quantization/base_config.py +0 -0
  80. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/radix_attention.py +0 -0
  81. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/layers/torchao_utils.py +0 -0
  82. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/lora/lora.py +0 -0
  83. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/lora/lora_config.py +0 -0
  84. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/lora/lora_manager.py +0 -0
  85. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/image_processor.py +0 -0
  86. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/schedule_policy.py +0 -0
  87. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/managers/tp_worker.py +0 -0
  88. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/mem_cache/base_prefix_cache.py +0 -0
  89. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/mem_cache/chunk_cache.py +0 -0
  90. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/mem_cache/flush_cache.py +0 -0
  91. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/mem_cache/radix_cache.py +0 -0
  92. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/mm_utils.py +0 -0
  93. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/baichuan.py +0 -0
  94. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/chatglm.py +0 -0
  95. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/commandr.py +0 -0
  96. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/dbrx.py +0 -0
  97. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/deepseek.py +0 -0
  98. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/deepseek_v2.py +0 -0
  99. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/exaone.py +0 -0
  100. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/gemma.py +0 -0
  101. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/gemma2.py +0 -0
  102. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/gpt_bigcode.py +0 -0
  103. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/grok.py +0 -0
  104. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/internlm2.py +0 -0
  105. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/llama.py +0 -0
  106. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/llama_classification.py +0 -0
  107. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/llama_embedding.py +0 -0
  108. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/llama_reward.py +0 -0
  109. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/llava.py +0 -0
  110. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/llavavid.py +0 -0
  111. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/minicpm.py +0 -0
  112. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/minicpm3.py +0 -0
  113. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/mistral.py +0 -0
  114. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/mixtral.py +0 -0
  115. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/mixtral_quant.py +0 -0
  116. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/mllama.py +0 -0
  117. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/olmo.py +0 -0
  118. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/olmoe.py +0 -0
  119. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/qwen.py +0 -0
  120. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/qwen2.py +0 -0
  121. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/qwen2_moe.py +0 -0
  122. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/qwen2_vl.py +0 -0
  123. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/stablelm.py +0 -0
  124. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/torch_native_llama.py +0 -0
  125. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/xverse.py +0 -0
  126. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/xverse_moe.py +0 -0
  127. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/models/yivl.py +0 -0
  128. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/openai_api/adapter.py +0 -0
  129. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/openai_api/protocol.py +0 -0
  130. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/penaltylib/__init__.py +0 -0
  131. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/penaltylib/orchestrator.py +0 -0
  132. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -0
  133. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -0
  134. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -0
  135. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/sampling/sampling_batch_info.py +0 -0
  136. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/server_args.py +0 -0
  137. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/srt/utils.py +0 -0
  138. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/few_shot_gsm8k.py +0 -0
  139. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/few_shot_gsm8k_engine.py +0 -0
  140. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/runners.py +0 -0
  141. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/simple_eval_common.py +0 -0
  142. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/simple_eval_gpqa.py +0 -0
  143. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/simple_eval_humaneval.py +0 -0
  144. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/simple_eval_math.py +0 -0
  145. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/simple_eval_mgsm.py +0 -0
  146. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/simple_eval_mmlu.py +0 -0
  147. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/test_activation.py +0 -0
  148. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/test_layernorm.py +0 -0
  149. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/test/test_programs.py +0 -0
  150. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang/utils.py +0 -0
  151. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang.egg-info/SOURCES.txt +0 -0
  152. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang.egg-info/dependency_links.txt +0 -0
  153. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang.egg-info/requires.txt +0 -0
  154. {sglang-0.3.4.post1 → sglang-0.3.4.post2}/sglang.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.3.4.post1
+ Version: 0.3.4.post2
  Summary: SGLang is yet another fast serving framework for large language models and vision language models.
  License: Apache License
  Version 2.0, January 2004
@@ -328,23 +328,27 @@ You can install SGLang using any of the methods below.
  pip install --upgrade pip
  pip install "sglang[all]"

- # Install FlashInfer CUDA kernels
+ # Install FlashInfer accelerated kernels
  pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
  ```

+ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
+
  ### Method 2: From source
  ```
  # Use the last release branch
- git clone -b v0.3.4.post1 https://github.com/sgl-project/sglang.git
+ git clone -b v0.3.4.post2 https://github.com/sgl-project/sglang.git
  cd sglang

  pip install --upgrade pip
  pip install -e "python[all]"

- # Install FlashInfer CUDA kernels
+ # Install FlashInfer accelerated kernels
  pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
  ```

+ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
+
  ### Method 3: Using docker
  The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
  Replace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).
@@ -498,7 +502,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
  ```
  python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
  ```
- - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
+ - To enable the experimental overlapped scheduler, add `--enable-overlap-scheduler`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currenly.
+ - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currenly.
  - To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
  - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
  - To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
@@ -519,7 +524,6 @@ We also provide an inference engine **without a HTTP server**. For example,
  ```python
  import sglang as sgl

-
  def main():
  prompts = [
  "Hello, my name is",
@@ -539,12 +543,8 @@ if __name__ == "__main__":
  main()
  ```

- This can be used for:
-
- 1. **Offline Batch Inference**
- 2. **Building Custom Servers**
-
- You can view the full example [here](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine)
+ This can be used for offline batch inference and building custom servers.
+ You can view the full example [here](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine).

  ### Supported Models
@@ -552,7 +552,7 @@ You can view the full example [here](https://github.com/sgl-project/sglang/tree/
  - Llama / Llama 2 / Llama 3 / Llama 3.1
  - Mistral / Mixtral / Mistral NeMo
  - Gemma / Gemma 2
- - Qwen / Qwen 2 / Qwen 2 MoE
+ - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
  - DeepSeek / DeepSeek 2
  - OLMoE
  - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
@@ -712,7 +712,6 @@ print(state["answer_1"])
  ```

  #### More Examples
-
  Anthropic and VertexAI (Gemini) models are also supported.
  You can find more examples at [examples/quick_start](examples/frontend_language/quick_start).
README.md
@@ -56,23 +56,27 @@ You can install SGLang using any of the methods below.
  pip install --upgrade pip
  pip install "sglang[all]"

- # Install FlashInfer CUDA kernels
+ # Install FlashInfer accelerated kernels
  pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
  ```

+ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
+
  ### Method 2: From source
  ```
  # Use the last release branch
- git clone -b v0.3.4.post1 https://github.com/sgl-project/sglang.git
+ git clone -b v0.3.4.post2 https://github.com/sgl-project/sglang.git
  cd sglang

  pip install --upgrade pip
  pip install -e "python[all]"

- # Install FlashInfer CUDA kernels
+ # Install FlashInfer accelerated kernels
  pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
  ```

+ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
+
  ### Method 3: Using docker
  The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
  Replace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).
@@ -226,7 +230,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
  ```
  python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
  ```
- - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
+ - To enable the experimental overlapped scheduler, add `--enable-overlap-scheduler`. It overlaps CPU scheduler with GPU computation and can accelerate almost all workloads. This does not work for constrained decoding currenly.
+ - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currenly.
  - To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
  - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
  - To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
@@ -247,7 +252,6 @@ We also provide an inference engine **without a HTTP server**. For example,
  ```python
  import sglang as sgl

-
  def main():
  prompts = [
  "Hello, my name is",
@@ -267,12 +271,8 @@ if __name__ == "__main__":
  main()
  ```

- This can be used for:
-
- 1. **Offline Batch Inference**
- 2. **Building Custom Servers**
-
- You can view the full example [here](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine)
+ This can be used for offline batch inference and building custom servers.
+ You can view the full example [here](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine).

  ### Supported Models
@@ -280,7 +280,7 @@ You can view the full example [here](https://github.com/sgl-project/sglang/tree/
  - Llama / Llama 2 / Llama 3 / Llama 3.1
  - Mistral / Mixtral / Mistral NeMo
  - Gemma / Gemma 2
- - Qwen / Qwen 2 / Qwen 2 MoE
+ - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL
  - DeepSeek / DeepSeek 2
  - OLMoE
  - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
@@ -440,7 +440,6 @@ print(state["answer_1"])
  ```

  #### More Examples
-
  Anthropic and VertexAI (Gemini) models are also supported.
  You can find more examples at [examples/quick_start](examples/frontend_language/quick_start).
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "sglang"
- version = "0.3.4.post1"
+ version = "0.3.4.post2"
  description = "SGLang is yet another fast serving framework for large language models and vision language models."
  readme = "README.md"
  requires-python = ">=3.8"
sglang/srt/configs/model_config.py
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ import logging
+ import os
  from enum import IntEnum, auto
  from typing import Optional

@@ -20,6 +22,8 @@ from transformers import PretrainedConfig

  from sglang.srt.hf_transformers_utils import get_config, get_context_length

+ logger = logging.getLogger(__name__)
+

  class AttentionArch(IntEnum):
  MLA = auto()
@@ -46,10 +50,29 @@ class ModelConfig:
  model_override_args=model_override_args,
  )
  self.hf_text_config = get_hf_text_config(self.hf_config)
+ derived_context_len = get_context_length(self.hf_text_config)
+ allow_long_context = os.environ.get(
+ "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
+ )
+
  if context_length is not None:
- self.context_len = context_length
+ if context_length > derived_context_len:
+ if allow_long_context:
+ logger.warning(
+ f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+ f"This may lead to incorrect model outputs or CUDA errors."
+ )
+ self.context_len = context_length
+ else:
+ raise ValueError(
+ f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+ f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
+ f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
+ )
+ else:
+ self.context_len = context_length
  else:
- self.context_len = get_context_length(self.hf_text_config)
+ self.context_len = derived_context_len

  # Unify the config keys for hf_text_config
  self.head_dim = getattr(
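A minimal, self-contained sketch of the new context-length check shown in the hunk above. The helper name `resolve_context_len` and the 4096/8192 values are illustrative; only the env var `SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN` and the accept/warn/raise behavior come from the diff.

```python
import logging
import os

logger = logging.getLogger(__name__)


def resolve_context_len(user_context_len, derived_context_len):
    # A user-specified context length larger than the one derived from the
    # HF config is rejected unless the override env var is set.
    allow_long_context = os.environ.get(
        "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
    )
    if user_context_len is None:
        return derived_context_len
    if user_context_len <= derived_context_len:
        return user_context_len
    if allow_long_context:
        logger.warning(
            "context_length %d exceeds the derived value %d; this may lead to "
            "incorrect outputs or CUDA errors.",
            user_context_len,
            derived_context_len,
        )
        return user_context_len
    raise ValueError(
        "Set SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 to override the "
        f"derived context length ({derived_context_len})."
    )


if __name__ == "__main__":
    os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
    print(resolve_context_len(8192, 4096))  # warns, then returns 8192
```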
sglang/srt/constrained/fsm_cache.py
@@ -73,9 +73,16 @@ class FSMCache(BaseToolCache):
  def init_value(self, key):
  key_type, key_string = key
  if key_type == "json":
- regex = build_regex_from_schema(
- key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
- )
+ try:
+ regex = build_regex_from_schema(
+ key_string,
+ whitespace_pattern=self.constrained_json_whitespace_pattern,
+ )
+ except NotImplementedError as e:
+ logger.warning(
+ f"skip invalid json schema: json_schema={key_string}, {e=}"
+ )
+ return None, key_string
  elif key_type == "regex":
  regex = key_string
  else:
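The hunk above turns an unsupported JSON schema from a hard failure into a skipped cache entry. A rough, self-contained sketch of that guard pattern; `compile_schema_to_regex` is a stand-in for the real `build_regex_from_schema` call, and the "unsupported feature" it rejects is invented for the example.

```python
import logging

logger = logging.getLogger(__name__)


def compile_schema_to_regex(schema: str, whitespace_pattern=None) -> str:
    # Stand-in for the schema-to-regex compiler; schema features it cannot
    # handle are signaled with NotImplementedError, as in the diff.
    if "patternProperties" in schema:
        raise NotImplementedError("patternProperties is not supported")
    return r"\{.*\}"  # placeholder regex


def init_json_value(key_string: str, whitespace_pattern=None):
    # Mirrors the new try/except: return (None, schema) instead of crashing
    # the whole cache when one schema cannot be compiled.
    try:
        regex = compile_schema_to_regex(
            key_string, whitespace_pattern=whitespace_pattern
        )
    except NotImplementedError as e:
        logger.warning(f"skip invalid json schema: json_schema={key_string}, {e=}")
        return None, key_string
    return regex, key_string


print(init_json_value('{"type": "object"}'))
print(init_json_value('{"patternProperties": {}}'))
```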
sglang/srt/hf_transformers_utils.py
@@ -163,6 +163,8 @@ def get_tokenizer(
  "Using a slow tokenizer. This might cause a significant "
  "slowdown. Consider using a fast tokenizer instead."
  )
+
+ attach_additional_stop_token_ids(tokenizer)
  return tokenizer


@@ -181,4 +183,16 @@ def get_processor(
  tokenizer_revision=tokenizer_revision,
  **kwargs,
  )
+
+ attach_additional_stop_token_ids(processor.tokenizer)
  return processor
+
+
+ def attach_additional_stop_token_ids(tokenizer):
+ # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+ if "<|eom_id|>" in tokenizer.get_added_vocab():
+ tokenizer.additional_stop_token_ids = set(
+ [tokenizer.get_added_vocab()["<|eom_id|>"]]
+ )
+ else:
+ tokenizer.additional_stop_token_ids = None
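A small sketch of what the new helper attaches and how a stop check can consume it, using a toy tokenizer object in place of a real Hugging Face tokenizer (the token ids are shown for illustration only). The same attribute is read by the `matched_eos` logic added to `schedule_batch.py` later in this diff.

```python
class ToyTokenizer:
    """Stand-in exposing the two things the helper relies on."""

    eos_token_id = 128009  # id shown for illustration

    def get_added_vocab(self):
        return {"<|eom_id|>": 128008}  # id shown for illustration


def attach_additional_stop_token_ids(tokenizer):
    # Same logic as the diff: record <|eom_id|> (emitted by Llama 3 tool use)
    # so generation can stop on it in addition to the regular EOS token.
    if "<|eom_id|>" in tokenizer.get_added_vocab():
        tokenizer.additional_stop_token_ids = set(
            [tokenizer.get_added_vocab()["<|eom_id|>"]]
        )
    else:
        tokenizer.additional_stop_token_ids = None


def is_stop_token(tokenizer, last_token_id: int) -> bool:
    matched = last_token_id == tokenizer.eos_token_id
    if tokenizer.additional_stop_token_ids:
        matched |= last_token_id in tokenizer.additional_stop_token_ids
    return matched


tok = ToyTokenizer()
attach_additional_stop_token_ids(tok)
print(is_stop_token(tok, 128008))  # True: <|eom_id|> now ends generation
```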
sglang/srt/layers/attention/flashinfer_backend.py
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
  def update(
  self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
  ):
- # Keep the signature for type checking, will be initialized during runtime
+ # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()

  def update_single_wrapper(
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
  kv_start_idx,
  ):
  bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indices = torch.empty(
  paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
  )
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.update = self.update_single_wrapper

  def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
- # Keep the signature for type checking, will be initialized during runtime
+ # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()

  def update_single_wrapper(
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
  use_ragged,
  ):
  bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
  self.max_context_len,
  )

+ qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
  qo_indptr = qo_indptr[: bs + 1]
- qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

  # extend part
  if use_ragged:
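The `kv_indptr`/`qo_indptr` hunks above only change where the cumulative sum is written: into the first `bs + 1` slots of the preallocated buffer, before the `[: bs + 1]` view is taken. A small torch sketch of how such an indptr buffer is filled (buffer size, lengths, and CPU placement are illustrative):

```python
import torch

max_bs = 8  # size of the preallocated buffer (illustrative)
kv_indptr_buf = torch.zeros(max_bs + 1, dtype=torch.int32)

# Per-request KV lengths for a batch of 3 (illustrative values).
paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
bs = paged_kernel_lens.numel()

# New order from the diff: write the cumsum into the persistent buffer first,
# then take the [: bs + 1] view that the attention wrapper consumes.
kv_indptr_buf[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr_buf[: bs + 1]

# Request i owns KV slots kv_indptr[i] : kv_indptr[i + 1].
print(kv_indptr.tolist())  # [0, 3, 8, 10]
```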
sglang/srt/layers/logits_processor.py
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
  # The logits of the next tokens. shape: [#seq, vocab_size]
  next_token_logits: torch.Tensor
  # The logprobs of the next tokens. shape: [#seq, vocab_size]
- next_token_logprobs: torch.Tensor
+ next_token_logprobs: torch.Tensor = None

  # The normlaized logprobs of prompts. shape: [#seq]
- normalized_prompt_logprobs: torch.Tensor
+ normalized_prompt_logprobs: torch.Tensor = None
  # The logprobs of input tokens. shape: [#token, vocab_size]
- input_token_logprobs: torch.Tensor
+ input_token_logprobs: torch.Tensor = None

  # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
- input_top_logprobs: List
+ input_top_logprobs: List = None
  # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
- output_top_logprobs: List
+ output_top_logprobs: List = None


  @dataclasses.dataclass
sglang/srt/layers/rotary_embedding.py
@@ -22,64 +22,33 @@ class MRotaryEmbedding:

  @staticmethod
  def get_input_positions(
- input_tokens: List[int],
+ input_tokens: torch.Tensor,
  image_grid_thw: Union[List[List[int]], torch.Tensor],
- video_grid_thw: Union[List[List[int]], torch.Tensor],
- image_token_id: int,
- video_token_id: int,
  vision_start_token_id: int,
- vision_end_token_id: int,
  spatial_merge_size: int,
  context_len: int = 0,
- extend_prefix_len: int = 0,
  ) -> Tuple[List[List[int]], int]:
  """Get mrope input positions and delta value."""

  if isinstance(image_grid_thw, torch.Tensor):
  image_grid_thw = image_grid_thw.tolist()
- if isinstance(video_grid_thw, torch.Tensor):
- video_grid_thw = video_grid_thw.tolist()

- input_tokens_tensor = torch.tensor(input_tokens)
  vision_start_indices = torch.argwhere(
- input_tokens_tensor == vision_start_token_id
+ input_tokens == vision_start_token_id
  ).squeeze(1)
- vision_tokens = input_tokens_tensor[vision_start_indices + 1]
- image_nums = (vision_tokens == image_token_id).sum()
- video_nums = (vision_tokens == video_token_id).sum()
+ image_indices = vision_start_indices + 1
+ image_nums = image_indices.shape[0]
  llm_pos_ids_list: list = []

  st = 0
- remain_images, remain_videos = image_nums, video_nums
-
- image_index, video_index = 0, 0
- for _ in range(image_nums + video_nums):
- if image_token_id in input_tokens and remain_images > 0:
- ed_image = input_tokens.index(image_token_id, st)
- else:
- ed_image = len(input_tokens) + 1
- if video_token_id in input_tokens and remain_videos > 0:
- ed_video = input_tokens.index(video_token_id, st)
- else:
- ed_video = len(input_tokens) + 1
- if ed_image < ed_video:
- t, h, w = (
- image_grid_thw[image_index][0],
- image_grid_thw[image_index][1],
- image_grid_thw[image_index][2],
- )
- image_index += 1
- remain_images -= 1
- ed = ed_image
- else:
- t, h, w = (
- video_grid_thw[video_index][0],
- video_grid_thw[video_index][1],
- video_grid_thw[video_index][2],
- )
- video_index += 1
- remain_videos -= 1
- ed = ed_video
+ input_tokens_len = input_tokens.shape[0]
+ for image_index in range(image_nums):
+ ed = image_indices[image_index].item()
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
  llm_grid_t, llm_grid_h, llm_grid_w = (
  t,
  h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
  )
  st = ed + llm_grid_t * llm_grid_h * llm_grid_w

- if st < len(input_tokens):
+ if st < input_tokens_len:
  st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_len = len(input_tokens) - st
+ text_len = input_tokens_len - st
  llm_pos_ids_list.append(
  torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
  )

  llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
  llm_positions = llm_positions[:, context_len:]
- mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
- llm_positions += extend_prefix_len
-
+ mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
  return llm_positions.tolist(), mrope_position_delta

  @staticmethod
sglang/srt/layers/sampler.py
@@ -1,4 +1,5 @@
  import logging
+ import os
  from typing import Union

  import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
  top_p_renorm_prob,
  )

+
+ # Crash on warning if we are running CI tests
+ crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
  logger = logging.getLogger(__name__)
@@ -33,56 +39,62 @@ class Sampler(nn.Module):
  if isinstance(logits, LogitsProcessorOutput):
  logits = logits.next_token_logits

- # Post process logits
  logits = logits.contiguous()
- logits.div_(sampling_info.temperatures)
- probs = torch.softmax(logits, dim=-1)
- logits = None
- del logits
-
- if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
- logger.warning("Detected errors during sampling! NaN in the probability.")
- probs = torch.where(
- torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+
+ if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+ logger.warning("Detected errors during sampling! NaN in the logits.")
+ logits = torch.where(
+ torch.isnan(logits), torch.full_like(logits, -1e5), logits
  )
+ exit(1) if crash_on_warning else None

  if sampling_info.is_all_greedy:
  # Use torch.argmax if all requests use greedy sampling
- batch_next_token_ids = torch.argmax(probs, -1)
- elif global_server_args_dict["sampling_backend"] == "flashinfer":
- max_top_k_round, batch_size = 32, probs.shape[0]
- uniform_samples = torch.rand(
- (max_top_k_round, batch_size), device=probs.device
- )
- if sampling_info.need_min_p_sampling:
- probs = top_k_renorm_prob(probs, sampling_info.top_ks)
- probs = top_p_renorm_prob(probs, sampling_info.top_ps)
- batch_next_token_ids, success = min_p_sampling_from_probs(
- probs, uniform_samples, sampling_info.min_ps
+ batch_next_token_ids = torch.argmax(logits, -1)
+ else:
+ # Post process logits
+ logits.div_(sampling_info.temperatures)
+ probs = torch.softmax(logits, dim=-1)
+ logits = None
+ del logits
+
+ if global_server_args_dict["sampling_backend"] == "flashinfer":
+ max_top_k_round, batch_size = 32, probs.shape[0]
+ uniform_samples = torch.rand(
+ (max_top_k_round, batch_size), device=probs.device
  )
- else:
- batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+ if sampling_info.need_min_p_sampling:
+ probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+ probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+ batch_next_token_ids, success = min_p_sampling_from_probs(
+ probs, uniform_samples, sampling_info.min_ps
+ )
+ else:
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+ probs,
+ uniform_samples,
+ sampling_info.top_ks,
+ sampling_info.top_ps,
+ filter_apply_order="joint",
+ )
+
+ if not torch.all(success):
+ logger.warning("Detected errors during sampling!")
+ batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+ elif global_server_args_dict["sampling_backend"] == "pytorch":
+ # A slower fallback implementation with torch native operations.
+ batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
  probs,
- uniform_samples,
  sampling_info.top_ks,
  sampling_info.top_ps,
- filter_apply_order="joint",
+ sampling_info.min_ps,
+ )
+ else:
+ raise ValueError(
+ f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
  )

- if not torch.all(success):
- logger.warning("Detected errors during sampling!")
- batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
- elif global_server_args_dict["sampling_backend"] == "pytorch":
- # Here we provide a slower fallback implementation.
- batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
- probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
- )
- else:
- raise ValueError(
- f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
- )
-
- return batch_next_token_ids
+ return batch_next_token_ids.to(torch.int32)


  def top_k_top_p_min_p_sampling_from_probs_torch(
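For reference, a self-contained torch sketch of the kind of top-k / top-p / min-p filtering the `pytorch` fallback path performs. This is an illustrative re-implementation under common conventions, not the exact `top_k_top_p_min_p_sampling_from_probs_torch` from the package.

```python
import torch


def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
    # probs: [bs, vocab]; top_ks: [bs] int; top_ps, min_ps: [bs] float
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)

    # min-p: drop tokens whose probability is below min_p * p_max
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

    # top-p: drop the tail once the cumulative mass before a token exceeds top_p
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0

    # top-k: keep only the k highest-probability tokens per row
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    probs_sort[ranks.view(1, -1) >= top_ks.view(-1, 1)] = 0.0

    # renormalize, sample in sorted space, then map back to vocabulary ids
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


probs = torch.softmax(torch.randn(2, 10), dim=-1)
top_ks = torch.tensor([5, 3])
top_ps = torch.tensor([0.9, 0.8])
min_ps = torch.tensor([0.05, 0.0])
print(sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps))
```

The top-1 token always survives all three filters, so the renormalization never divides by zero.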
sglang/srt/managers/data_parallel_controller.py
@@ -156,7 +156,7 @@ class DataParallelController:
  else:
  # Send other control messages to all workers
  for worker in self.workers:
- worker.queue.put(recv_req)
+ worker.send_pyobj(recv_req)


  def run_data_parallel_controller_process(
sglang/srt/managers/detokenizer_manager.py
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
  BatchEmbeddingOut,
  BatchStrOut,
  BatchTokenIDOut,
+ GetMemPoolSizeReqOutput,
  UpdateWeightReqOutput,
  )
  from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
  # If it is a weight update request, no detokenization is needed.
  self.send_to_tokenizer.send_pyobj(recv_obj)
  continue
+ elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+ self.send_to_tokenizer.send_pyobj(recv_obj)
+ continue
  elif self.tokenizer is None:
  # If the tokenizer is skipped, no detokenization is needed
  self.send_to_tokenizer.send_pyobj(recv_obj)
sglang/srt/managers/io_struct.py
@@ -353,3 +353,13 @@ class AbortReq:
  class ProfileReq(Enum):
  START_PROFILE = 1
  STOP_PROFILE = 2
+
+
+ @dataclass
+ class GetMemPoolSizeReq:
+ pass
+
+
+ @dataclass
+ class GetMemPoolSizeReqOutput:
+ size: int
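`GetMemPoolSizeReq` / `GetMemPoolSizeReqOutput` follow the same pattern as the other control messages: plain dataclasses pickled over ZeroMQ with `send_pyobj` / `recv_pyobj`, matching the `worker.send_pyobj(recv_req)` change in the data-parallel controller above. A minimal standalone sketch of that request/response pattern; the socket types, endpoint name, and wiring are illustrative, not the package's actual process layout.

```python
from dataclasses import dataclass

import zmq


@dataclass
class GetMemPoolSizeReq:
    pass


@dataclass
class GetMemPoolSizeReqOutput:
    size: int


ctx = zmq.Context()
scheduler = ctx.socket(zmq.PAIR)  # stands in for the scheduler process
scheduler.bind("inproc://mem_pool_size")
frontend = ctx.socket(zmq.PAIR)   # stands in for the tokenizer manager
frontend.connect("inproc://mem_pool_size")

# The frontend asks for the KV-cache pool size; the scheduler replies with a dataclass.
frontend.send_pyobj(GetMemPoolSizeReq())
req = scheduler.recv_pyobj()
assert isinstance(req, GetMemPoolSizeReq)
scheduler.send_pyobj(GetMemPoolSizeReqOutput(size=420000))  # illustrative size

print(frontend.recv_pyobj())  # GetMemPoolSizeReqOutput(size=420000)
```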
sglang/srt/managers/schedule_batch.py
@@ -334,15 +334,20 @@ class Req:

  last_token_id = self.output_ids[-1]

- matched_eos = last_token_id in self.sampling_params.stop_token_ids
+ matched_eos = False

+ # Check stop token ids
+ if self.sampling_params.stop_token_ids:
+ matched_eos = last_token_id in self.sampling_params.stop_token_ids
  if self.tokenizer is not None:
  matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+ if self.tokenizer.additional_stop_token_ids:
+ matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
  if matched_eos and not self.sampling_params.ignore_eos:
  self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
  return

+ # Check stop strings
  if len(self.sampling_params.stop_strs) > 0:
  tail_str = self.tokenizer.decode(
  self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
@@ -514,7 +519,12 @@ class ScheduleBatch:
  out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

  if out_cache_loc is None:
- logger.error("Prefill out of memory. Try to lower your batch size.")
+ phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+ logger.error(
+ f"{phase_str} out of memory. Try to lower your batch size.\n"
+ f"Try to allocate {num_tokens} tokens.\n"
+ f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+ )
  if self.tree_cache is not None:
  self.tree_cache.pretty_print()
  exit(1)