sglang 0.3.5.post1__tar.gz → 0.3.5.post2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/PKG-INFO +2 -2
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/pyproject.toml +2 -2
- sglang-0.3.5.post2/sglang/bench_offline_throughput.py +309 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_serving.py +44 -30
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/base_grammar_backend.py +4 -3
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/outlines_backend.py +24 -24
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/xgrammar_backend.py +40 -4
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/patch.py +4 -2
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/detokenizer_manager.py +0 -14
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/scheduler.py +6 -2
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/model_executor/model_runner.py +4 -1
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/openai_api/adapter.py +5 -2
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/openai_api/protocol.py +29 -26
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/server.py +2 -1
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/server_args.py +24 -3
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/utils.py +33 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_utils.py +4 -4
- sglang-0.3.5.post2/sglang/version.py +1 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/PKG-INFO +2 -2
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/SOURCES.txt +1 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/requires.txt +1 -1
- sglang-0.3.5.post1/sglang/version.py +0 -1
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/LICENSE +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/README.md +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/setup.cfg +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/api.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_latency.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_server_latency.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/check_env.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/global_config.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/anthropic.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/base_backend.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/litellm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/openai.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/runtime_endpoint.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/backend/vertexai.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/chat_template.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/choices.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/compiler.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/interpreter.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/ir.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/lang/tracer.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/launch_server.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/launch_server_llavavid.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/exaone.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/model_config.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/configs/qwen2vl.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/outlines_jump_forward.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/conversation.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/hf_transformers_utils.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/activation.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/double_sparsity_backend.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/flashinfer_backend.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_backend.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/decode_attention.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/extend_attention.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/attention/triton_ops/prefill_attention.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/fused_moe.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/layer.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/layernorm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/linear.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/logits_processor.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/pooler.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/quantization/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/quantization/base_config.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/radix_attention.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/rotary_embedding.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/sampler.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/torchao_utils.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/vocab_parallel_embedding.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/lora/lora.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/lora/lora_config.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/lora/lora_manager.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/data_parallel_controller.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/image_processor.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/io_struct.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/schedule_batch.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/schedule_policy.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/tokenizer_manager.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/tp_worker.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/managers/tp_worker_overlap_thread.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/base_prefix_cache.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/chunk_cache.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/flush_cache.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/memory_pool.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mem_cache/radix_cache.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/metrics/collector.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/metrics/func_timer.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/mm_utils.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/model_executor/cuda_graph_runner.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/model_executor/forward_batch_info.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/baichuan.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/chatglm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/commandr.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/dbrx.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/deepseek.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/deepseek_v2.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/exaone.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gemma.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gemma2.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gemma2_reward.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gpt2.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/gpt_bigcode.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/grok.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/internlm2.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/internlm2_reward.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama_classification.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama_embedding.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llama_reward.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llava.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/llavavid.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/minicpm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/minicpm3.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mistral.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mixtral.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mixtral_quant.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/mllama.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/olmo.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/olmoe.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen2.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen2_moe.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/qwen2_vl.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/stablelm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/torch_native_llama.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/xverse.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/xverse_moe.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/models/yivl.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/__init__.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/orchestrator.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/sampling_batch_info.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/sampling/sampling_params.py +2 -2
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/few_shot_gsm8k.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/few_shot_gsm8k_engine.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/run_eval.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/runners.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_common.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_gpqa.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_humaneval.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_math.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_mgsm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/simple_eval_mmlu.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/srt/sampling/penaltylib/utils.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_activation.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_layernorm.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/test/test_programs.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/utils.py +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/dependency_links.txt +0 -0
- {sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang.egg-info/top_level.txt +0 -0
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.3.5.post1
+Version: 0.3.5.post2
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
 Version 2.0, January 2004
@@ -233,7 +233,7 @@ Requires-Dist: torchao; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
-Requires-Dist: outlines>=0.0.44; extra == "runtime-common"
+Requires-Dist: outlines<0.1.0,>=0.0.44; extra == "runtime-common"
 Requires-Dist: modelscope; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sglang"
-version = "0.3.5.post1"
+version = "0.3.5.post2"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"
@@ -19,7 +19,7 @@ dependencies = ["requests", "tqdm", "numpy", "IPython"]
 runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
 "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
 "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
-"outlines>=0.0.44", "modelscope"]
+"outlines>=0.0.44,<0.1.0", "modelscope"]
 srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
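The only dependency change in this release is the new upper bound on outlines. A minimal sketch of what the new specifier accepts and rejects, using the third-party `packaging` library (already listed in `runtime_common`); the printed values are assumptions based on the pin shown above:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=0.0.44,<0.1.0")  # the new outlines pin from pyproject.toml

print(Version("0.0.46") in spec)  # True: existing 0.0.x releases still satisfy it
print(Version("0.1.0") in spec)   # False: the 0.1.x line is now excluded
```

The same constraint is mirrored in the PKG-INFO metadata above.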
sglang-0.3.5.post2/sglang/bench_offline_throughput.py (new file)

@@ -0,0 +1,309 @@
+"""
+Benchmark the throughput of using the offline LLM engine.
+This script does not launch a server.
+It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
+
+# Usage
+## Sharegpt dataset with default args
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct
+
+## Random dataset with default args
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random
+
+## Shared prefix dataset with default args
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name generated-shared-prefix
+
+## Sharegpt dataset on runtime backend
+python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --backend runtime
+"""
+
+import argparse
+import dataclasses
+import json
+import logging
+import random
+import time
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from sglang.api import Engine
+from sglang.bench_serving import (
+    get_dataset,
+    get_tokenizer,
+    sample_random_requests,
+    set_ulimit,
+)
+from sglang.srt.server import Runtime
+from sglang.srt.server_args import ServerArgs
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    backend: str = "engine"
+    result_filename: str = ""
+    dataset_name: str = "sharegpt"
+    dataset_path: str = ""
+    num_prompts: int = 1000
+    sharegpt_output_len: Optional[int] = None
+    random_input_len: int = 1024
+    random_output_len: int = 1024
+    random_range_ratio: float = 0.0
+    gen_num_groups: int = 64
+    gen_prompts_per_group: int = 16
+    gen_system_prompt_len: int = 2048
+    gen_question_len: int = 128
+    gen_output_len: int = 256
+    disable_ignore_eos: bool = False
+    seed: int = 1
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+        parser.add_argument(
+            "--dataset-name",
+            type=str,
+            default="sharegpt",
+            choices=["sharegpt", "random", "generated-shared-prefix"],
+            help="Name of the dataset to benchmark on.",
+        )
+        parser.add_argument(
+            "--dataset-path", type=str, default="", help="Path to the dataset."
+        )
+        parser.add_argument(
+            "--num-prompts",
+            type=int,
+            default=BenchArgs.num_prompts,
+            help="Number of prompts to process. Default is 1000.",
+        )
+        parser.add_argument(
+            "--sharegpt-output-len",
+            type=int,
+            default=BenchArgs.sharegpt_output_len,
+            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
+        )
+        parser.add_argument(
+            "--random-input-len",
+            type=int,
+            default=BenchArgs.random_input_len,
+            help="Number of input tokens per request, used only for random dataset.",
+        )
+        parser.add_argument(
+            "--random-output-len",
+            type=int,
+            default=BenchArgs.random_output_len,
+            help="Number of output tokens per request, used only for random dataset.",
+        )
+        parser.add_argument(
+            "--random-range-ratio",
+            type=float,
+            default=BenchArgs.random_range_ratio,
+            help="Range of sampled ratio of input/output length, "
+            "used only for random dataset.",
+        )
+        parser.add_argument(
+            "--gen-num-groups",
+            type=int,
+            default=BenchArgs.gen_num_groups,
+            help="Number of groups with shared prefix, used"
+            "only for generate-shared-prefix",
+        )
+        parser.add_argument(
+            "--gen-prompts-per-group",
+            type=int,
+            default=BenchArgs.gen_prompts_per_group,
+            help="Number of prompts per group of shared prefix, used"
+            "only for generate-shared-prefix",
+        )
+        parser.add_argument(
+            "--gen-system-prompt-len",
+            type=int,
+            default=BenchArgs.gen_system_prompt_len,
+            help="System prompt length, used" "only for generate-shared-prefix",
+        )
+        parser.add_argument(
+            "--gen-question-len",
+            type=int,
+            default=BenchArgs.gen_question_len,
+            help="Question length, used" "only for generate-shared-prefix",
+        )
+        parser.add_argument(
+            "--gen-output-len",
+            type=int,
+            default=BenchArgs.gen_output_len,
+            help="Target length in tokens for outputs in generated-shared-prefix dataset",
+        )
+        parser.add_argument(
+            "--disable-ignore-eos",
+            type=bool,
+            default=BenchArgs.disable_ignore_eos,
+            help="Disable ignore EOS token",
+        )
+        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def throughput_test_once(
+    backend_name: str,
+    backend,
+    reqs: List[Tuple[str, int, int]],
+    ignore_eos: bool,
+):
+    measurement_results = {
+        "backend": backend_name,
+        "successful_requests": len(reqs),
+        "total_latency": -1,
+        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_output_tokens": -1,
+        "request_throughput": -1,
+        "input_throughput": -1,
+        "output_throughput": -1,
+        "total_throughput": -1,
+    }
+
+    prompt = [r[0] for r in reqs]
+    sampling_params = [
+        {
+            "temperature": 0,
+            "max_new_tokens": r[2],
+            "ignore_eos": ignore_eos,
+        }
+        for r in reqs
+    ]
+
+    st = time.perf_counter()
+    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+    latency = time.perf_counter() - st
+
+    if backend_name == "runtime":
+        gen_out = json.loads(gen_out)
+
+    measurement_results["total_latency"] = latency
+    measurement_results["total_output_tokens"] = sum(
+        o["meta_info"]["completion_tokens"] for o in gen_out
+    )
+    measurement_results["request_throughput"] = (
+        measurement_results["successful_requests"] / latency
+    )
+    measurement_results["input_throughput"] = (
+        measurement_results["total_input_tokens"] / latency
+    )
+    measurement_results["output_throughput"] = (
+        measurement_results["total_output_tokens"] / latency
+    )
+    measurement_results["total_throughput"] = (
+        measurement_results["total_input_tokens"]
+        + measurement_results["total_output_tokens"]
+    ) / latency
+
+    return measurement_results
+
+
+def throughput_test(
+    server_args: ServerArgs,
+    bench_args: BenchArgs,
+):
+    if bench_args.backend == "engine":
+        backend = Engine(**dataclasses.asdict(server_args))
+        if not backend:
+            raise ValueError("Please provide valid engine arguments")
+    elif bench_args.backend == "runtime":
+        backend = Runtime(**dataclasses.asdict(server_args))
+    else:
+        raise ValueError('Please set backend to either "engine" or "runtime"')
+
+    tokenizer_id = server_args.model_path
+    tokenizer = get_tokenizer(tokenizer_id)
+
+    # Set global environmnets
+    set_ulimit()
+    random.seed(bench_args.seed)
+    np.random.seed(bench_args.seed)
+
+    # Read dataset
+    input_requests = get_dataset(bench_args, tokenizer)
+
+    warmup_requests = sample_random_requests(
+        input_len=20,
+        output_len=4,
+        num_prompts=2,
+        range_ratio=0.8,
+        tokenizer=tokenizer,
+        dataset_path=bench_args.dataset_path,
+    )
+
+    # Warm up
+    throughput_test_once(
+        backend_name=bench_args.backend,
+        backend=backend,
+        reqs=warmup_requests,
+        ignore_eos=not bench_args.disable_ignore_eos,
+    )
+
+    result = throughput_test_once(
+        backend_name=bench_args.backend,
+        backend=backend,
+        reqs=input_requests,
+        ignore_eos=not bench_args.disable_ignore_eos,
+    )
+
+    if bench_args.result_filename:
+        with open(bench_args.result_filename, "a") as fout:
+            fout.write(json.dumps(result) + "\n")
+
+    print(
+        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
+    )
+    print("{:<40} {:<10}".format("Backend:", result["backend"]))
+    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
+    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
+    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
+    print(
+        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
+    )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Request throughput (req/s):", result["request_throughput"]
+        )
+    )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Input token throughput (tok/s):", result["input_throughput"]
+        )
+    )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Output token throughput (tok/s):", result["output_throughput"]
+        )
+    )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Total token throughput (tok/s):", result["total_throughput"]
+        )
+    )
+    print("=" * 50)
+
+    return result
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    throughput_test(server_args, bench_args)
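The new file is mainly a CLI entry point (see the usage lines in its docstring), but `throughput_test` can also be driven from Python. A hedged sketch under the assumptions that the `srt` extra is installed, a GPU is available, and the placeholder model path can be loaded; the result keys come from `throughput_test_once` above:

```python
from sglang.bench_offline_throughput import BenchArgs, throughput_test
from sglang.srt.server_args import ServerArgs

# The model path is a placeholder; substitute any model sglang can serve.
server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
bench_args = BenchArgs(backend="engine", dataset_name="random", num_prompts=100)

result = throughput_test(server_args, bench_args)
print(result["total_throughput"])  # tok/s, input and output tokens combined
```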
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/bench_serving.py

@@ -421,6 +421,37 @@ def get_tokenizer(
     )
 
 
+def get_dataset(args, tokenizer):
+    if args.dataset_name == "sharegpt":
+        input_requests = sample_sharegpt_requests(
+            dataset_path=args.dataset_path,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
+        )
+    elif args.dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=args.random_input_len,
+            output_len=args.random_output_len,
+            num_prompts=args.num_prompts,
+            range_ratio=args.random_range_ratio,
+            tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
+        )
+    elif args.dataset_name == "generated-shared-prefix":
+        input_requests = sample_generated_shared_prefix_requests(
+            num_groups=args.gen_num_groups,
+            prompts_per_group=args.gen_prompts_per_group,
+            system_prompt_len=args.gen_system_prompt_len,
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
+    else:
+        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    return input_requests
+
+
 ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
     "sglang-native": async_request_sglang_generate,
@@ -443,6 +474,8 @@ class BenchmarkMetrics:
     input_throughput: float
     output_throughput: float
     output_throughput_retokenized: float
+    total_throughput: float
+    total_throughput_retokenized: float
     mean_ttft_ms: float
     median_ttft_ms: float
     std_ttft_ms: float
@@ -590,7 +623,6 @@ def sample_random_requests(
         (data["conversations"][0]["value"], data["conversations"][1]["value"])
         for data in dataset
     ]
-
     # Shuffle the dataset.
     random.shuffle(dataset)
 
@@ -764,6 +796,9 @@ def calculate_metrics(
         input_throughput=total_input / dur_s,
         output_throughput=sum(output_lens) / dur_s,
         output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
+        total_throughput=(total_input + sum(output_lens)) / dur_s,
+        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
+        / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
@@ -881,6 +916,11 @@ async def benchmark(
             "Output token throughput (tok/s):", metrics.output_throughput
         )
     )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Total token throughput (tok/s):", metrics.total_throughput
+        )
+    )
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1098,35 +1138,7 @@ def run_benchmark(args_: argparse.Namespace):
 
     tokenizer = get_tokenizer(tokenizer_id)
 
-    if args.dataset_name == "sharegpt":
-        assert args.random_input_len is None and args.random_output_len is None
-        input_requests = sample_sharegpt_requests(
-            dataset_path=args.dataset_path,
-            num_requests=args.num_prompts,
-            tokenizer=tokenizer,
-            fixed_output_len=args.sharegpt_output_len,
-        )
-    elif args.dataset_name == "random":
-        assert args.random_input_len is not None and args.random_output_len is not None
-        input_requests = sample_random_requests(
-            input_len=args.random_input_len,
-            output_len=args.random_output_len,
-            num_prompts=args.num_prompts,
-            range_ratio=args.random_range_ratio,
-            tokenizer=tokenizer,
-            dataset_path=args.dataset_path,
-        )
-    elif args.dataset_name == "generated-shared-prefix":
-        input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
-            tokenizer=tokenizer,
-        )
-    else:
-        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    input_requests = get_dataset(args, tokenizer)
 
     if not args.multi:
         return asyncio.run(
@@ -1229,10 +1241,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
+        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
+        default=1024,
         type=int,
         help="Number of output tokens per request, used only for random dataset.",
     )
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/base_grammar_backend.py

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""The baseclass of
+"""The baseclass of a backend for grammar-guided constrained decoding."""
 
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
@@ -52,7 +52,7 @@ class BaseGrammarBackend:
         else:
             entry.value = self.init_value_impl(key)
             entry.event.set()
-        return entry.value.copy()
+        return entry.value.copy() if entry.value else None
 
     def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
         raise NotImplementedError()
@@ -62,7 +62,8 @@ class BaseGrammarBackend:
         entry = self.cache.get(key)
         if not entry or not entry.event.is_set():
             return None
-        return self.cache[key].value.copy()
+        val = self.cache[key].value
+        return val.copy() if val else None
 
     def get_future_value(self, key: Tuple[str, str]) -> Future:
         return self.executor.submit(self.init_value, key)
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/outlines_backend.py

@@ -19,9 +19,12 @@ import json
 import logging
 from typing import Dict, List, Optional, Tuple, Union
 
+import interegular
 import torch
 from outlines.fsm.guide import RegexGuide
+from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.models.transformers import TransformerTokenizer
+from pydantic import BaseModel
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -32,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
 logger = logging.getLogger(__name__)
 
 
-try:
-    from outlines.fsm.json_schema import build_regex_from_object
-except ImportError:
-    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-    # which only accepts string schema as input.
-    from outlines.fsm.json_schema import build_regex_from_schema
-    from pydantic import BaseModel
-
-    def build_regex_from_object(
-        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-    ):
-        if isinstance(object, type(BaseModel)):
-            schema = json.dumps(object.model_json_schema())
-        elif isinstance(object, Dict):
-            schema = json.dumps(object)
-        else:
-            schema = object
-        return build_regex_from_schema(schema, whitespace_pattern)
-
-
 class OutlinesGrammar(BaseGrammarObject):
     def __init__(
         self,
@@ -147,19 +130,36 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                     key_string,
                     whitespace_pattern=self.whitespace_pattern,
                 )
-            except NotImplementedError as e:
+            except (NotImplementedError, json.decoder.JSONDecodeError) as e:
                 logger.warning(
-                    f"
+                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
-                return None
+                return None
         elif key_type == "regex":
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-        guide = RegexGuide(regex, self.outlines_tokenizer)
+        try:
+            guide = RegexGuide(regex, self.outlines_tokenizer)
+        except interegular.patterns.InvalidSyntax as e:
+            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
+            return None
+
         if self.allow_jump_forward:
             jump_forward_map = OutlinesJumpForwardMap(regex)
         else:
             jump_forward_map = None
         return OutlinesGrammar(guide, jump_forward_map)
+
+
+def build_regex_from_object(
+    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+):
+    if isinstance(object, type(BaseModel)):
+        schema = json.dumps(object.model_json_schema())
+    elif isinstance(object, Dict):
+        schema = json.dumps(object)
+    else:
+        schema = object
+    return build_regex_from_schema(schema, whitespace_pattern)
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/constrained/xgrammar_backend.py

@@ -15,16 +15,29 @@ limitations under the License.
 
 """Constrained decoding with xgrammar backend."""
 
+import logging
 from typing import List, Tuple
 
 import torch
-
+
+try:
+    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+
+    import_error = None
+except ImportError as e:
+    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
+        ImportError
+    )
+    import_error = e
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
 
+logger = logging.getLogger(__name__)
+
+
 MAX_ROLLBACK_TOKENS = 10
 
 
@@ -91,15 +104,37 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         vocab_size: int,
     ):
         super().__init__()
+
+        if import_error:
+            logger.warning(
+                f"Ignore import error for the grammar backend: {import_error}"
+            )
+            self.grammar_cache = None
+            return
+
         self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
+        if import_error:
+            raise import_error
+
         key_type, key_string = key
         if key_type == "json":
-
+            try:
+                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
+                    key_string
+                )
+            except RuntimeError as e:
+                logging.warning(
+                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
+                )
+                return None
         elif key_type == "regex":
-
+            logger.warning(
+                "regex hasn't been supported by xgrammar yet. This is skipped."
+            )
+            return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
@@ -111,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
-        self.grammar_cache
+        if self.grammar_cache:
+            self.grammar_cache.clear()
{sglang-0.3.5.post1 → sglang-0.3.5.post2}/sglang/srt/layers/fused_moe/patch.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 from torch.nn import functional as F
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
     renormalize: bool,
     topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
+    assert custom_routing_function is None
     topk_weights, topk_ids = select_experts_native(
         hidden_states=x,
         router_logits=router_logits,
@@ -114,4 +116,4 @@
     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
-    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))