sglang 0.3.5.post1__tar.gz → 0.3.5.post2__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their registry.
Files changed (162)
  1. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/PKG-INFO +2 -2
  2. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/pyproject.toml +2 -2
  3. sglang-0.3.5.post2/sglang/bench_offline_throughput.py +309 -0
  4. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_serving.py +44 -30
  5. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/base_grammar_backend.py +4 -3
  6. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/outlines_backend.py +24 -24
  7. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/xgrammar_backend.py +40 -4
  8. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/patch.py +4 -2
  9. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/detokenizer_manager.py +0 -14
  10. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/scheduler.py +6 -2
  11. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/model_executor/model_runner.py +4 -1
  12. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/openai_api/adapter.py +5 -2
  13. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/openai_api/protocol.py +29 -26
  14. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/server.py +2 -1
  15. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/server_args.py +24 -3
  16. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/utils.py +33 -0
  17. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_utils.py +4 -4
  18. sglang-0.3.5.post2/sglang/version.py +1 -0
  19. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/PKG-INFO +2 -2
  20. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/SOURCES.txt +1 -0
  21. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/requires.txt +1 -1
  22. sglang-0.3.5.post1/sglang/version.py +0 -1
  23. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/LICENSE +0 -0
  24. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/README.md +0 -0
  25. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/setup.cfg +0 -0
  26. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/__init__.py +0 -0
  27. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/api.py +0 -0
  28. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_latency.py +0 -0
  29. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_server_latency.py +0 -0
  30. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/check_env.py +0 -0
  31. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/global_config.py +0 -0
  32. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/__init__.py +0 -0
  33. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/__init__.py +0 -0
  34. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/anthropic.py +0 -0
  35. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/base_backend.py +0 -0
  36. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/litellm.py +0 -0
  37. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/openai.py +0 -0
  38. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/runtime_endpoint.py +0 -0
  39. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/vertexai.py +0 -0
  40. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/chat_template.py +0 -0
  41. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/choices.py +0 -0
  42. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/compiler.py +0 -0
  43. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/interpreter.py +0 -0
  44. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/ir.py +0 -0
  45. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/tracer.py +0 -0
  46. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/launch_server.py +0 -0
  47. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/launch_server_llavavid.py +0 -0
  48. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/__init__.py +0 -0
  49. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/exaone.py +0 -0
  50. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/model_config.py +0 -0
  51. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/qwen2vl.py +0 -0
  52. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/__init__.py +0 -0
  53. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/outlines_jump_forward.py +0 -0
  54. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/conversation.py +0 -0
  55. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/hf_transformers_utils.py +0 -0
  56. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/activation.py +0 -0
  57. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/__init__.py +0 -0
  58. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/double_sparsity_backend.py +0 -0
  59. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/flashinfer_backend.py +0 -0
  60. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_backend.py +0 -0
  61. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/decode_attention.py +0 -0
  62. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +0 -0
  63. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/extend_attention.py +0 -0
  64. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/prefill_attention.py +0 -0
  65. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/__init__.py +0 -0
  66. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/fused_moe.py +0 -0
  67. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/layer.py +0 -0
  68. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/layernorm.py +0 -0
  69. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/linear.py +0 -0
  70. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/logits_processor.py +0 -0
  71. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/pooler.py +0 -0
  72. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/quantization/__init__.py +0 -0
  73. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/quantization/base_config.py +0 -0
  74. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/radix_attention.py +0 -0
  75. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/rotary_embedding.py +0 -0
  76. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/sampler.py +0 -0
  77. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/torchao_utils.py +0 -0
  78. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/vocab_parallel_embedding.py +0 -0
  79. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/lora/lora.py +0 -0
  80. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/lora/lora_config.py +0 -0
  81. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/lora/lora_manager.py +0 -0
  82. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/data_parallel_controller.py +0 -0
  83. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/image_processor.py +0 -0
  84. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/io_struct.py +0 -0
  85. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/schedule_batch.py +0 -0
  86. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/schedule_policy.py +0 -0
  87. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/tokenizer_manager.py +0 -0
  88. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/tp_worker.py +0 -0
  89. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/tp_worker_overlap_thread.py +0 -0
  90. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/base_prefix_cache.py +0 -0
  91. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/chunk_cache.py +0 -0
  92. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/flush_cache.py +0 -0
  93. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/memory_pool.py +0 -0
  94. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/radix_cache.py +0 -0
  95. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/metrics/collector.py +0 -0
  96. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/metrics/func_timer.py +0 -0
  97. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mm_utils.py +0 -0
  98. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/model_executor/cuda_graph_runner.py +0 -0
  99. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/model_executor/forward_batch_info.py +0 -0
  100. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/baichuan.py +0 -0
  101. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/chatglm.py +0 -0
  102. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/commandr.py +0 -0
  103. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/dbrx.py +0 -0
  104. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/deepseek.py +0 -0
  105. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/deepseek_v2.py +0 -0
  106. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/exaone.py +0 -0
  107. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gemma.py +0 -0
  108. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gemma2.py +0 -0
  109. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gemma2_reward.py +0 -0
  110. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gpt2.py +0 -0
  111. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gpt_bigcode.py +0 -0
  112. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/grok.py +0 -0
  113. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/internlm2.py +0 -0
  114. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/internlm2_reward.py +0 -0
  115. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama.py +0 -0
  116. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama_classification.py +0 -0
  117. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama_embedding.py +0 -0
  118. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama_reward.py +0 -0
  119. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llava.py +0 -0
  120. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llavavid.py +0 -0
  121. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/minicpm.py +0 -0
  122. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/minicpm3.py +0 -0
  123. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mistral.py +0 -0
  124. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mixtral.py +0 -0
  125. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mixtral_quant.py +0 -0
  126. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mllama.py +0 -0
  127. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/olmo.py +0 -0
  128. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/olmoe.py +0 -0
  129. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen.py +0 -0
  130. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen2.py +0 -0
  131. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen2_moe.py +0 -0
  132. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen2_vl.py +0 -0
  133. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/stablelm.py +0 -0
  134. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/torch_native_llama.py +0 -0
  135. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/xverse.py +0 -0
  136. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/xverse_moe.py +0 -0
  137. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/yivl.py +0 -0
  138. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/__init__.py +0 -0
  139. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/orchestrator.py +0 -0
  140. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -0
  141. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +0 -0
  142. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -0
  143. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -0
  144. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/sampling_batch_info.py +0 -0
  145. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/sampling_params.py +2 -2
  146. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/few_shot_gsm8k.py +0 -0
  147. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/few_shot_gsm8k_engine.py +0 -0
  148. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/run_eval.py +0 -0
  149. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/runners.py +0 -0
  150. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_common.py +0 -0
  151. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_gpqa.py +0 -0
  152. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_humaneval.py +0 -0
  153. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_math.py +0 -0
  154. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_mgsm.py +0 -0
  155. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_mmlu.py +0 -0
  156. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/srt/sampling/penaltylib/utils.py +0 -0
  157. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_activation.py +0 -0
  158. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_layernorm.py +0 -0
  159. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_programs.py +0 -0
  160. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/utils.py +0 -0
  161. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/dependency_links.txt +0 -0
  162. {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.3.5.post1
+ Version: 0.3.5.post2
  Summary: SGLang is yet another fast serving framework for large language models and vision language models.
  License: Apache License
  Version 2.0, January 2004
@@ -233,7 +233,7 @@ Requires-Dist: torchao; extra == "runtime-common"
  Requires-Dist: uvicorn; extra == "runtime-common"
  Requires-Dist: uvloop; extra == "runtime-common"
  Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
- Requires-Dist: outlines>=0.0.44; extra == "runtime-common"
+ Requires-Dist: outlines<0.1.0,>=0.0.44; extra == "runtime-common"
  Requires-Dist: modelscope; extra == "runtime-common"
  Provides-Extra: srt
  Requires-Dist: sglang[runtime_common]; extra == "srt"
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "sglang"
- version = "0.3.5.post1"
+ version = "0.3.5.post2"
  description = "SGLang is yet another fast serving framework for large language models and vision language models."
  readme = "README.md"
  requires-python = ">=3.8"
@@ -19,7 +19,7 @@ dependencies = ["requests", "tqdm", "numpy", "IPython"]
  runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
  "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
  "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
- "outlines>=0.0.44", "modelscope"]
+ "outlines>=0.0.44,<0.1.0", "modelscope"]
  srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
 
  # HIP (Heterogeneous-computing Interface for Portability) for AMD
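Note on the dependency change: both hunks above add an upper bound so that pip resolves outlines to the 0.0.x series (0.0.44 or newer but below 0.1.0, presumably because outlines 0.1.0 reorganized the fsm modules imported in outlines_backend.py below). A small sketch to check whether an installed environment satisfies the new pin; the specifier string is copied from pyproject.toml above, everything else is illustrative:

    from importlib.metadata import version
    from packaging.specifiers import SpecifierSet

    # Constraint introduced in 0.3.5.post2, copied from pyproject.toml.
    spec = SpecifierSet(">=0.0.44,<0.1.0")
    installed = version("outlines")
    print(installed, "satisfies pin" if installed in spec else "violates pin")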
@@ -0,0 +1,309 @@
+ """
+ Benchmark the throughput of using the offline LLM engine.
+ This script does not launch a server.
+ It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
+
+ # Usage
+ ## Sharegpt dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
+
+ ## Random dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
+
+ ## Shared prefix dataset with default args
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
+
+ ## Sharegpt dataset on runtime backend
+ python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
+ """
+
+ import argparse
+ import dataclasses
+ import json
+ import logging
+ import random
+ import time
+ from typing import List, Optional, Tuple
+
+ import numpy as np
+
+ from sglang.api import Engine
+ from sglang.bench_serving import (
+     get_dataset,
+     get_tokenizer,
+     sample_random_requests,
+     set_ulimit,
+ )
+ from sglang.srt.server import Runtime
+ from sglang.srt.server_args import ServerArgs
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+     backend: str = "engine"
+     result_filename: str = ""
+     dataset_name: str = "sharegpt"
+     dataset_path: str = ""
+     num_prompts: int = 1000
+     sharegpt_output_len: Optional[int] = None
+     random_input_len: int = 1024
+     random_output_len: int = 1024
+     random_range_ratio: float = 0.0
+     gen_num_groups: int = 64
+     gen_prompts_per_group: int = 16
+     gen_system_prompt_len: int = 2048
+     gen_question_len: int = 128
+     gen_output_len: int = 256
+     disable_ignore_eos: bool = False
+     seed: int = 1
+
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--backend", type=str, default=BenchArgs.backend)
+         parser.add_argument(
+             "--result-filename", type=str, default=BenchArgs.result_filename
+         )
+         parser.add_argument(
+             "--dataset-name",
+             type=str,
+             default="sharegpt",
+             choices=["sharegpt", "random", "generated-shared-prefix"],
+             help="Name of the dataset to benchmark on.",
+         )
+         parser.add_argument(
+             "--dataset-path", type=str, default="", help="Path to the dataset."
+         )
+         parser.add_argument(
+             "--num-prompts",
+             type=int,
+             default=BenchArgs.num_prompts,
+             help="Number of prompts to process. Default is 1000.",
+         )
+         parser.add_argument(
+             "--sharegpt-output-len",
+             type=int,
+             default=BenchArgs.sharegpt_output_len,
+             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+         )
+         parser.add_argument(
+             "--random-input-len",
+             type=int,
+             default=BenchArgs.random_input_len,
+             help="Number of input tokens per request, used only for random dataset.",
+         )
+         parser.add_argument(
+             "--random-output-len",
+             type=int,
+             default=BenchArgs.random_output_len,
+             help="Number of output tokens per request, used only for random dataset.",
+         )
+         parser.add_argument(
+             "--random-range-ratio",
+             type=float,
+             default=BenchArgs.random_range_ratio,
+             help="Range of sampled ratio of input/output length, "
+             "used only for random dataset.",
+         )
+         parser.add_argument(
+             "--gen-num-groups",
+             type=int,
+             default=BenchArgs.gen_num_groups,
+             help="Number of groups with shared prefix, used "
+             "only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gen-prompts-per-group",
+             type=int,
+             default=BenchArgs.gen_prompts_per_group,
+             help="Number of prompts per group of shared prefix, used "
+             "only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gen-system-prompt-len",
+             type=int,
+             default=BenchArgs.gen_system_prompt_len,
+             help="System prompt length, used only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gen-question-len",
+             type=int,
+             default=BenchArgs.gen_question_len,
+             help="Question length, used only for generated-shared-prefix",
+         )
+         parser.add_argument(
+             "--gen-output-len",
+             type=int,
+             default=BenchArgs.gen_output_len,
+             help="Target length in tokens for outputs in generated-shared-prefix dataset",
+         )
+         parser.add_argument(
+             "--disable-ignore-eos",
+             action="store_true",
+             default=BenchArgs.disable_ignore_eos,
+             help="Disable ignoring the EOS token",
+         )
+         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+
+     @classmethod
+     def from_cli_args(cls, args: argparse.Namespace):
+         attrs = [attr.name for attr in dataclasses.fields(cls)]
+         return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+ def throughput_test_once(
+     backend_name: str,
+     backend,
+     reqs: List[Tuple[str, int, int]],
+     ignore_eos: bool,
+ ):
+     measurement_results = {
+         "backend": backend_name,
+         "successful_requests": len(reqs),
+         "total_latency": -1,
+         "total_input_tokens": sum(r[1] for r in reqs),
+         "total_output_tokens": -1,
+         "request_throughput": -1,
+         "input_throughput": -1,
+         "output_throughput": -1,
+         "total_throughput": -1,
+     }
+
+     prompt = [r[0] for r in reqs]
+     sampling_params = [
+         {
+             "temperature": 0,
+             "max_new_tokens": r[2],
+             "ignore_eos": ignore_eos,
+         }
+         for r in reqs
+     ]
+
+     st = time.perf_counter()
+     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+     latency = time.perf_counter() - st
+
+     if backend_name == "runtime":
+         gen_out = json.loads(gen_out)
+
+     measurement_results["total_latency"] = latency
+     measurement_results["total_output_tokens"] = sum(
+         o["meta_info"]["completion_tokens"] for o in gen_out
+     )
+     measurement_results["request_throughput"] = (
+         measurement_results["successful_requests"] / latency
+     )
+     measurement_results["input_throughput"] = (
+         measurement_results["total_input_tokens"] / latency
+     )
+     measurement_results["output_throughput"] = (
+         measurement_results["total_output_tokens"] / latency
+     )
+     measurement_results["total_throughput"] = (
+         measurement_results["total_input_tokens"]
+         + measurement_results["total_output_tokens"]
+     ) / latency
+
+     return measurement_results
+
+
+ def throughput_test(
+     server_args: ServerArgs,
+     bench_args: BenchArgs,
+ ):
+     if bench_args.backend == "engine":
+         backend = Engine(**dataclasses.asdict(server_args))
+         if not backend:
+             raise ValueError("Please provide valid engine arguments")
+     elif bench_args.backend == "runtime":
+         backend = Runtime(**dataclasses.asdict(server_args))
+     else:
+         raise ValueError('Please set backend to either "engine" or "runtime"')
+
+     tokenizer_id = server_args.model_path
+     tokenizer = get_tokenizer(tokenizer_id)
+
+     # Set global environments
+     set_ulimit()
+     random.seed(bench_args.seed)
+     np.random.seed(bench_args.seed)
+
+     # Read dataset
+     input_requests = get_dataset(bench_args, tokenizer)
+
+     warmup_requests = sample_random_requests(
+         input_len=20,
+         output_len=4,
+         num_prompts=2,
+         range_ratio=0.8,
+         tokenizer=tokenizer,
+         dataset_path=bench_args.dataset_path,
+     )
+
+     # Warm up
+     throughput_test_once(
+         backend_name=bench_args.backend,
+         backend=backend,
+         reqs=warmup_requests,
+         ignore_eos=not bench_args.disable_ignore_eos,
+     )
+
+     result = throughput_test_once(
+         backend_name=bench_args.backend,
+         backend=backend,
+         reqs=input_requests,
+         ignore_eos=not bench_args.disable_ignore_eos,
+     )
+
+     if bench_args.result_filename:
+         with open(bench_args.result_filename, "a") as fout:
+             fout.write(json.dumps(result) + "\n")
+
+     print(
+         "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
+     )
+     print("{:<40} {:<10}".format("Backend:", result["backend"]))
+     print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
+     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
+     print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
+     print(
+         "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Request throughput (req/s):", result["request_throughput"]
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Input token throughput (tok/s):", result["input_throughput"]
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Output token throughput (tok/s):", result["output_throughput"]
+         )
+     )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Total token throughput (tok/s):", result["total_throughput"]
+         )
+     )
+     print("=" * 50)
+
+     return result
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     ServerArgs.add_cli_args(parser)
+     BenchArgs.add_cli_args(parser)
+     args = parser.parse_args()
+     server_args = ServerArgs.from_cli_args(args)
+     bench_args = BenchArgs.from_cli_args(args)
+
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+
+     throughput_test(server_args, bench_args)
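The new bench_offline_throughput.py can also be driven programmatically instead of via the CLI shown in its docstring. A minimal sketch (assumes ServerArgs can be constructed with just model_path, as when launching a server; the model path is illustrative):

    from sglang.bench_offline_throughput import BenchArgs, throughput_test
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    bench_args = BenchArgs(backend="engine", dataset_name="random", num_prompts=100)

    result = throughput_test(server_args, bench_args)
    print(result["output_throughput"])  # tok/s, one key of the result dict built above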
@@ -421,6 +421,37 @@ def get_tokenizer(
  )
 
 
+ def get_dataset(args, tokenizer):
+     if args.dataset_name == "sharegpt":
+         input_requests = sample_sharegpt_requests(
+             dataset_path=args.dataset_path,
+             num_requests=args.num_prompts,
+             tokenizer=tokenizer,
+             fixed_output_len=args.sharegpt_output_len,
+         )
+     elif args.dataset_name == "random":
+         input_requests = sample_random_requests(
+             input_len=args.random_input_len,
+             output_len=args.random_output_len,
+             num_prompts=args.num_prompts,
+             range_ratio=args.random_range_ratio,
+             tokenizer=tokenizer,
+             dataset_path=args.dataset_path,
+         )
+     elif args.dataset_name == "generated-shared-prefix":
+         input_requests = sample_generated_shared_prefix_requests(
+             num_groups=args.gen_num_groups,
+             prompts_per_group=args.gen_prompts_per_group,
+             system_prompt_len=args.gen_system_prompt_len,
+             question_len=args.gen_question_len,
+             output_len=args.gen_output_len,
+             tokenizer=tokenizer,
+         )
+     else:
+         raise ValueError(f"Unknown dataset: {args.dataset_name}")
+     return input_requests
+
+
  ASYNC_REQUEST_FUNCS = {
      "sglang": async_request_sglang_generate,
      "sglang-native": async_request_sglang_generate,
@@ -443,6 +474,8 @@ class BenchmarkMetrics:
      input_throughput: float
      output_throughput: float
      output_throughput_retokenized: float
+     total_throughput: float
+     total_throughput_retokenized: float
      mean_ttft_ms: float
      median_ttft_ms: float
      std_ttft_ms: float
@@ -590,7 +623,6 @@ def sample_random_requests(
          (data["conversations"][0]["value"], data["conversations"][1]["value"])
          for data in dataset
      ]
-
      # Shuffle the dataset.
      random.shuffle(dataset)
 
@@ -764,6 +796,9 @@ def calculate_metrics(
          input_throughput=total_input / dur_s,
          output_throughput=sum(output_lens) / dur_s,
          output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
+         total_throughput=(total_input + sum(output_lens)) / dur_s,
+         total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
+         / dur_s,
          mean_ttft_ms=np.mean(ttfts or 0)
          * 1000,  # ttfts is empty if streaming is not supported by backend
          median_ttft_ms=np.median(ttfts or 0) * 1000,
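The two new metrics simply fold prompt tokens into the existing throughput calculation. A worked example with made-up numbers:

    # 90,000 prompt tokens plus 30,000 generated tokens over a 60 s run:
    total_input, output_tokens, dur_s = 90_000, 30_000, 60.0
    total_throughput = (total_input + output_tokens) / dur_s  # 2000.0 tok/s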
@@ -881,6 +916,11 @@ async def benchmark(
              "Output token throughput (tok/s):", metrics.output_throughput
          )
      )
+     print(
+         "{:<40} {:<10.2f}".format(
+             "Total token throughput (tok/s):", metrics.total_throughput
+         )
+     )
      print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
      print(
          "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1098,35 +1138,7 @@ def run_benchmark(args_: argparse.Namespace):
 
      tokenizer = get_tokenizer(tokenizer_id)
 
-     if args.dataset_name == "sharegpt":
-         assert args.random_input_len is None and args.random_output_len is None
-         input_requests = sample_sharegpt_requests(
-             dataset_path=args.dataset_path,
-             num_requests=args.num_prompts,
-             tokenizer=tokenizer,
-             fixed_output_len=args.sharegpt_output_len,
-         )
-     elif args.dataset_name == "random":
-         assert args.random_input_len is not None and args.random_output_len is not None
-         input_requests = sample_random_requests(
-             input_len=args.random_input_len,
-             output_len=args.random_output_len,
-             num_prompts=args.num_prompts,
-             range_ratio=args.random_range_ratio,
-             tokenizer=tokenizer,
-             dataset_path=args.dataset_path,
-         )
-     elif args.dataset_name == "generated-shared-prefix":
-         input_requests = sample_generated_shared_prefix_requests(
-             num_groups=args.gen_num_groups,
-             prompts_per_group=args.gen_prompts_per_group,
-             system_prompt_len=args.gen_system_prompt_len,
-             question_len=args.gen_question_len,
-             output_len=args.gen_output_len,
-             tokenizer=tokenizer,
-         )
-     else:
-         raise ValueError(f"Unknown dataset: {args.dataset_name}")
+     input_requests = get_dataset(args, tokenizer)
 
      if not args.multi:
          return asyncio.run(
@@ -1229,10 +1241,12 @@ if __name__ == "__main__":
      parser.add_argument(
          "--random-input-len",
          type=int,
+         default=1024,
          help="Number of input tokens per request, used only for random dataset.",
      )
      parser.add_argument(
          "--random-output-len",
+         default=1024,
          type=int,
          help="Number of output tokens per request, used only for random dataset.",
      )
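With these defaults in place, the random dataset no longer requires explicit length flags, so an invocation like the following should work out of the box (1024 input and 1024 output tokens per request unless overridden):

    python -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 100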
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """
 
- """The baseclass of backends for grammar-guided constrained decoding."""
+ """The base class of a backend for grammar-guided constrained decoding."""
 
  from concurrent.futures import Future, ThreadPoolExecutor
  from dataclasses import dataclass
@@ -52,7 +52,7 @@ class BaseGrammarBackend:
          else:
              entry.value = self.init_value_impl(key)
              entry.event.set()
-         return entry.value.copy()
+         return entry.value.copy() if entry.value else None
 
      def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
          raise NotImplementedError()
@@ -62,7 +62,8 @@ class BaseGrammarBackend:
          entry = self.cache.get(key)
          if not entry or not entry.event.is_set():
              return None
-         return self.cache[key].value.copy()
+         val = self.cache[key].value
+         return val.copy() if val else None
 
      def get_future_value(self, key: Tuple[str, str]) -> Future:
          return self.executor.submit(self.init_value, key)
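Both changes guard against entry.value being None, which can now happen because the backends below return None from init_value_impl for invalid grammars. A hedged sketch of the consumer-side pattern this enables (backend stands for any BaseGrammarBackend subclass; the key shape follows the diff):

    # A None result now means "grammar could not be compiled",
    # instead of an AttributeError on None.copy() inside the cache.
    future = backend.get_future_value(("json", '{"type": "object"}'))
    grammar = future.result()
    if grammar is None:
        pass  # the invalid schema/regex was skipped; fall back or reject the request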
@@ -19,9 +19,12 @@ import json
  import logging
  from typing import Dict, List, Optional, Tuple, Union
 
+ import interegular
  import torch
  from outlines.fsm.guide import RegexGuide
+ from outlines.fsm.json_schema import build_regex_from_schema
  from outlines.models.transformers import TransformerTokenizer
+ from pydantic import BaseModel
 
  from sglang.srt.constrained.base_grammar_backend import (
      BaseGrammarBackend,
@@ -32,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
  logger = logging.getLogger(__name__)
 
 
- try:
-     from outlines.fsm.json_schema import build_regex_from_object
- except ImportError:
-     # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-     # which only accepts string schema as input.
-     from outlines.fsm.json_schema import build_regex_from_schema
-     from pydantic import BaseModel
-
-     def build_regex_from_object(
-         object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-     ):
-         if isinstance(object, type(BaseModel)):
-             schema = json.dumps(object.model_json_schema())
-         elif isinstance(object, Dict):
-             schema = json.dumps(object)
-         else:
-             schema = object
-         return build_regex_from_schema(schema, whitespace_pattern)
-
-
  class OutlinesGrammar(BaseGrammarObject):
      def __init__(
          self,
@@ -147,19 +130,36 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                      key_string,
                      whitespace_pattern=self.whitespace_pattern,
                  )
-             except NotImplementedError as e:
+             except (NotImplementedError, json.decoder.JSONDecodeError) as e:
                  logger.warning(
-                     f"skip invalid json schema: json_schema={key_string}, {e=}"
+                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                  )
-                 return None, key_string
+                 return None
          elif key_type == "regex":
              regex = key_string
          else:
              raise ValueError(f"Invalid key_type: {key_type}")
 
-         guide = RegexGuide(regex, self.outlines_tokenizer)
+         try:
+             guide = RegexGuide(regex, self.outlines_tokenizer)
+         except interegular.patterns.InvalidSyntax as e:
+             logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
+             return None
+
          if self.allow_jump_forward:
              jump_forward_map = OutlinesJumpForwardMap(regex)
          else:
              jump_forward_map = None
          return OutlinesGrammar(guide, jump_forward_map)
+
+
+ def build_regex_from_object(
+     object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+ ):
+     if isinstance(object, type(BaseModel)):
+         schema = json.dumps(object.model_json_schema())
+     elif isinstance(object, Dict):
+         schema = json.dumps(object)
+     else:
+         schema = object
+     return build_regex_from_schema(schema, whitespace_pattern)
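build_regex_from_object is now a module-level helper rather than a fallback defined inside a try/except. Per the hunk above it still accepts a pydantic model class, a dict, or a JSON-schema string; a short usage sketch (the model is illustrative):

    from pydantic import BaseModel

    from sglang.srt.constrained.outlines_backend import build_regex_from_object

    class Answer(BaseModel):
        name: str
        score: int

    regex = build_regex_from_object(Answer)  # a regex string constraining decoding
    print(regex[:80])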
@@ -15,16 +15,29 @@ limitations under the License.
 
  """Constrained decoding with xgrammar backend."""
 
+ import logging
  from typing import List, Tuple
 
  import torch
- from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+
+ try:
+     from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+
+     import_error = None
+ except ImportError as e:
+     CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
+         ImportError
+     )
+     import_error = e
 
  from sglang.srt.constrained.base_grammar_backend import (
      BaseGrammarBackend,
      BaseGrammarObject,
  )
 
+ logger = logging.getLogger(__name__)
+
+
  MAX_ROLLBACK_TOKENS = 10
 
 
@@ -91,15 +104,37 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          vocab_size: int,
      ):
          super().__init__()
+
+         if import_error:
+             logger.warning(
+                 f"Ignore import error for the grammar backend: {import_error}"
+             )
+             self.grammar_cache = None
+             return
+
          self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
          self.vocab_size = vocab_size
 
      def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
+         if import_error:
+             raise import_error
+
          key_type, key_string = key
          if key_type == "json":
-             ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
+             try:
+                 ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
+                     key_string
+                 )
+             except RuntimeError as e:
+                 logging.warning(
+                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
+                 )
+                 return None
          elif key_type == "regex":
-             raise ValueError("regex hasn't been supported by xgrammar yet")
+             logger.warning(
+                 "regex hasn't been supported by xgrammar yet. This is skipped."
+             )
+             return None
          else:
              raise ValueError(f"Invalid key_type: {key_type}")
 
@@ -111,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
      def reset(self):
-         self.grammar_cache.clear()
+         if self.grammar_cache:
+             self.grammar_cache.clear()
@@ -1,4 +1,4 @@
- from typing import Optional
+ from typing import Callable, Optional
 
  import torch
  from torch.nn import functional as F
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
      renormalize: bool,
      topk_group: Optional[int] = None,
      num_expert_group: Optional[int] = None,
+     custom_routing_function: Optional[Callable] = None,
  ) -> torch.Tensor:
+     assert custom_routing_function is None
      topk_weights, topk_ids = select_experts_native(
          hidden_states=x,
          router_logits=router_logits,
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
      x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
      x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
      expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
-     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
+     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
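The cast matters because torch.einsum does not promote mixed dtypes: router top-k weights are often kept in float32 while expert outputs run in half precision, so the old line could raise a dtype-mismatch error. A self-contained sketch with illustrative shapes:

    import torch

    expert_outs = torch.randn(4, 2, 16, dtype=torch.float16)  # (t, a, i)
    topk_weights = torch.rand(4, 2, dtype=torch.float32)      # (t, a)

    # Casting the weights to the expert-output dtype keeps the einsum valid.
    out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
    print(out.shape)  # torch.Size([4, 16])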