sglang 0.3.6.post3__tar.gz → 0.4.0.post1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/PKG-INFO +5 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/README.md +3 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/pyproject.toml +2 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/__init__.py +1 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_one_batch.py +4 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_serving.py +13 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/check_env.py +1 -1
- sglang-0.4.0.post1/sglang/srt/_custom_ops.py +118 -0
- sglang-0.4.0.post1/sglang/srt/configs/device_config.py +17 -0
- sglang-0.4.0.post1/sglang/srt/configs/load_config.py +84 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/configs/model_config.py +161 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/configs/qwen2vl.py +5 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/constrained/outlines_backend.py +11 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/constrained/outlines_jump_forward.py +8 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang-0.4.0.post1/sglang/srt/distributed/__init__.py +3 -0
- sglang-0.4.0.post1/sglang/srt/distributed/communication_op.py +34 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang-0.4.0.post1/sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang-0.4.0.post1/sglang/srt/distributed/parallel_state.py +1275 -0
- sglang-0.4.0.post1/sglang/srt/distributed/utils.py +223 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/hf_transformers_utils.py +37 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/__init__.py +5 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang-0.4.0.post1/sglang/srt/layers/attention/torch_native_backend.py +299 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/triton_backend.py +22 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang-0.4.0.post1/sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang-0.4.0.post1/sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang-0.4.0.post1/sglang/srt/layers/ep_moe/layer.py +661 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/fused_moe_patch.py +20 -11
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/linear.py +1 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/logits_processor.py +17 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/quantization/__init__.py +36 -2
- sglang-0.4.0.post1/sglang/srt/layers/quantization/fp8.py +559 -0
- sglang-0.4.0.post1/sglang/srt/layers/quantization/fp8_utils.py +27 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/radix_attention.py +4 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/sampler.py +2 -0
- sglang-0.4.0.post1/sglang/srt/layers/torchao_utils.py +73 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/lora/lora.py +1 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/io_struct.py +48 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/schedule_batch.py +19 -14
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/schedule_policy.py +7 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/scheduler.py +145 -85
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/tokenizer_manager.py +166 -68
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/tp_worker.py +36 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/mem_cache/memory_pool.py +5 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/model_executor/forward_batch_info.py +9 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/model_executor/model_runner.py +146 -153
- sglang-0.4.0.post1/sglang/srt/model_loader/__init__.py +34 -0
- sglang-0.4.0.post1/sglang/srt/model_loader/loader.py +1139 -0
- sglang-0.4.0.post1/sglang/srt/model_loader/utils.py +41 -0
- sglang-0.4.0.post1/sglang/srt/model_loader/weight_utils.py +640 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/model_parallel.py +1 -5
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/baichuan.py +9 -10
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/chatglm.py +6 -15
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/commandr.py +4 -5
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/dbrx.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/deepseek.py +4 -11
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/deepseek_v2.py +90 -18
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/exaone.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/gemma.py +2 -6
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/gemma2.py +3 -14
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/gemma2_reward.py +0 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/gpt2.py +5 -12
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/gpt_bigcode.py +6 -22
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/grok.py +3 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/internlm2.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/internlm2_reward.py +0 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/llama.py +96 -31
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/llama_classification.py +1 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/llama_embedding.py +1 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/llama_reward.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/llava.py +1 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/llavavid.py +1 -2
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/minicpm.py +4 -7
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/minicpm3.py +6 -19
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/mixtral.py +24 -14
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/mixtral_quant.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/mllama.py +3 -7
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/olmo.py +2 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/olmo2.py +0 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/olmoe.py +3 -5
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/phi3_small.py +8 -13
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/qwen.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/qwen2.py +10 -9
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/qwen2_moe.py +4 -16
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/qwen2_vl.py +2 -6
- sglang-0.4.0.post1/sglang/srt/models/registry.py +99 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/stablelm.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/torch_native_llama.py +6 -17
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/xverse.py +2 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/xverse_moe.py +4 -11
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/yivl.py +2 -3
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/openai_api/adapter.py +9 -5
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/openai_api/protocol.py +1 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/sampling_batch_info.py +9 -8
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/server.py +270 -173
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/server_args.py +102 -29
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/utils.py +295 -28
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/test_utils.py +7 -0
- sglang-0.4.0.post1/sglang/version.py +1 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang.egg-info/PKG-INFO +5 -4
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang.egg-info/SOURCES.txt +27 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang.egg-info/requires.txt +1 -1
- sglang-0.3.6.post3/sglang/srt/layers/torchao_utils.py +0 -95
- sglang-0.3.6.post3/sglang/version.py +0 -1
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/LICENSE +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/setup.cfg +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/api.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_latency.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_offline_throughput.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_one_batch_server.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/global_config.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/__init__.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/__init__.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/anthropic.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/base_backend.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/litellm.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/openai.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/runtime_endpoint.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/backend/vertexai.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/chat_template.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/choices.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/compiler.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/interpreter.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/ir.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/lang/tracer.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/launch_server.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/launch_server_llavavid.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/configs/__init__.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/configs/exaone.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/constrained/__init__.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/constrained/base_grammar_backend.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/conversation.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/activation.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/triton_ops/decode_attention.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/attention/triton_ops/prefill_attention.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/custom_op_util.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/fused_moe_triton/__init__.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/fused_moe_triton/fused_moe.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/fused_moe_triton/layer.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/layernorm.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/pooler.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/quantization/base_config.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/layers/rotary_embedding.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/lora/lora_config.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/lora/lora_manager.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/data_parallel_controller.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/detokenizer_manager.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/image_processor.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/managers/session_controller.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/mem_cache/base_prefix_cache.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/mem_cache/chunk_cache.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/mem_cache/flush_cache.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/mem_cache/radix_cache.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/metrics/collector.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/metrics/func_timer.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/mm_utils.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/models/mistral.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/penaltylib/__init__.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/penaltylib/orchestrator.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/sampling/sampling_params.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/few_shot_gsm8k.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/few_shot_gsm8k_engine.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/run_eval.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/runners.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/simple_eval_common.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/simple_eval_gpqa.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/simple_eval_humaneval.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/simple_eval_math.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/simple_eval_mgsm.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/simple_eval_mmlu.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/srt/sampling/penaltylib/utils.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/test_activation.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/test_layernorm.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/test/test_programs.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/utils.py +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang.egg-info/dependency_links.txt +0 -0
- {sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang.egg-info/top_level.txt +0 -0
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.3.6.post3
+Version: 0.4.0.post1
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
         Version 2.0, January 2004
@@ -239,7 +239,7 @@ Requires-Dist: xgrammar>=0.1.4; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: torch; extra == "srt"
-Requires-Dist: vllm>=0.6.3.post1; extra == "srt"
+Requires-Dist: vllm<=0.6.4.post1,>=0.6.3.post1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: flashinfer>=0.1.6; extra == "srt"
 Provides-Extra: srt-hip
@@ -315,6 +315,7 @@ Requires-Dist: sglang[test]; extra == "dev-hpu"
 [**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
 
 ## News
+- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
 - [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
 - [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
 - [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
@@ -346,13 +347,13 @@ The core features include:
 - [Frontend: Structured Generation Language (SGLang)](https://sgl-project.github.io/frontend/frontend.html)
 
 ## Benchmark And Performance
-Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)
+Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)
 
 ## Roadmap
 [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487)
 
 ## Adoption and Sponsorship
-The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI.
+The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI.
 
 ## Acknowledgment and Citation
 We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/README.md

@@ -16,6 +16,7 @@
 [**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
 
 ## News
+- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
 - [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
 - [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
 - [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
@@ -47,13 +48,13 @@ The core features include:
 - [Frontend: Structured Generation Language (SGLang)](https://sgl-project.github.io/frontend/frontend.html)
 
 ## Benchmark And Performance
-Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)
+Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)
 
 ## Roadmap
 [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487)
 
 ## Adoption and Sponsorship
-The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI.
+The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI.
 
 ## Acknowledgment and Citation
 We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sglang"
-version = "0.3.6.post3"
+version = "0.4.0.post1"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"
@@ -23,7 +23,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop",
     "xgrammar>=0.1.4"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1", "cuda-python", "flashinfer>=0.1.6"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_one_batch.py

@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
     model_config = ModelConfig(
         server_args.model_path,
         trust_remote_code=server_args.trust_remote_code,
+        revision=server_args.revision,
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
+        is_embedding=server_args.is_embedding,
+        dtype=server_args.dtype,
+        quantization=server_args.quantization,
     )
     model_runner = ModelRunner(
         model_config=model_config,
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/bench_serving.py

@@ -51,6 +51,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    lora_name: str
     extra_request_body: Dict[str, Any]
 
 
@@ -319,6 +320,7 @@ async def async_request_sglang_generate(
             "ignore_eos": not args.disable_ignore_eos,
         },
         "stream": not args.disable_stream,
+        "lora_path": request_func_input.lora_name,
         **request_func_input.extra_request_body,
     }
     headers = {}
@@ -884,6 +886,7 @@ async def benchmark(
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
+    lora_name: str,
     extra_request_body: Dict[str, Any],
     profile: bool,
 ):
@@ -909,6 +912,7 @@ async def benchmark(
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
@@ -942,6 +946,7 @@ async def benchmark(
             api_url=api_url,
             prompt_len=prompt_len,
             output_len=output_len,
+            lora_name=lora_name,
             extra_request_body=extra_request_body,
         )
         tasks.append(
@@ -1247,6 +1252,7 @@ def run_benchmark(args_: argparse.Namespace):
             request_rate=args.request_rate,
             max_concurrency=args.max_concurrency,
             disable_tqdm=args.disable_tqdm,
+            lora_name=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
         )
@@ -1267,6 +1273,7 @@ def run_benchmark(args_: argparse.Namespace):
             request_rate=rate,
             max_concurrency=args.max_concurrency,
             disable_tqdm=args.disable_tqdm,
+            lora_name=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
         )
@@ -1451,5 +1458,11 @@ if __name__ == "__main__":
         help="Use Torch Profiler. The endpoint must be launched with "
         "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
     )
+    parser.add_argument(
+        "--lora-name",
+        type=str,
+        default=None,
+        help="The name of LoRA adapter",
+    )
     args = parser.parse_args()
     run_benchmark(args)
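The bench_serving.py hunks above thread a single new --lora-name flag from the CLI into the "lora_path" field of every generate request the benchmark sends. As a rough sketch of the resulting request body (the URL, prompt, and adapter name are placeholders, and the payload fields other than "lora_path" only approximate the dictionary built in async_request_sglang_generate), assuming an SGLang server is already running and the requests package is available:

    import requests  # illustration only; bench_serving itself uses aiohttp

    payload = {
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32, "ignore_eos": True},
        "stream": False,
        "lora_path": "my-adapter",  # filled from --lora-name by the change above
    }
    resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
    print(resp.json())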
sglang-0.4.0.post1/sglang/srt/_custom_ops.py (new file)

@@ -0,0 +1,118 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+import contextlib
+import functools
+import importlib
+import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+import torch.library
+
+from sglang.srt.utils import is_hpu
+
+logger = logging.getLogger(__name__)
+
+if not is_hpu():
+    try:
+        import custom_ar
+    except ImportError as e:
+        logger.warning("Failed to import from custom_ar with %r", e)
+
+
+def hint_on_error(fn):
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+
+        except NotImplementedError as e:
+            msg = (
+                "Error in calling custom op %s: %s\n"
+                "Not implemented or built, mostly likely because the current current device "
+                "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
+                "incorrectly while building)"
+            )
+            logger.error(msg, fn.__name__, e)
+            raise NotImplementedError(msg % (fn.__name__, e)) from e
+        except AttributeError as e:
+            msg = (
+                "Error in calling custom op %s: %s\n"
+                "Possibly you have built or installed an obsolete version of vllm.\n"
+                "Please try a clean build and install of vllm,"
+                "or remove old built files such as vllm/*cpython*.so and build/ ."
+            )
+            logger.error(msg, fn.__name__, e)
+            raise e
+
+    return wrapper
+
+
+# custom ar
+def init_custom_ar(
+    ipc_tensors: List[torch.Tensor],
+    rank_data: torch.Tensor,
+    rank: int,
+    full_nvlink: bool,
+) -> int:
+    return torch.ops._C_vllm_ar.init_custom_ar(
+        ipc_tensors, rank_data, rank, full_nvlink
+    )
+
+
+def all_reduce(
+    fa: int,
+    inp: torch.Tensor,
+    out: torch.Tensor,
+    reg_buffer: int,
+    reg_buffer_sz_bytes: int,
+) -> None:
+    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+
+def dispose(fa: int) -> None:
+    torch.ops._C_vllm_ar.dispose(fa)
+
+
+def meta_size() -> int:
+    return torch.ops._C_vllm_ar.meta_size()
+
+
+def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+
+
+def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+
+
+def register_graph_buffers(
+    fa: int, handles: List[List[int]], offsets: List[List[int]]
+) -> None:
+    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+
+
+# temporary fix for https://github.com/vllm-project/vllm/issues/5456
+# TODO: remove this in v0.6.0
+names_and_values = globals()
+names_and_values_to_update = {}
+# prepare variables to avoid dict size change during iteration
+k, v, arg = None, None, None
+fn_type = type(lambda x: x)
+for k, v in names_and_values.items():
+    # find functions that are defined in this file and have torch.Tensor
+    # in their annotations. `arg == "torch.Tensor"` is used to handle
+    # the case when users use `import __annotations__` to turn type
+    # hints into strings.
+    if (
+        isinstance(v, fn_type)
+        and v.__code__.co_filename == __file__
+        and any(
+            arg is torch.Tensor or arg == "torch.Tensor"
+            for arg in v.__annotations__.values()
+        )
+    ):
+        names_and_values_to_update[k] = hint_on_error(v)
+
+names_and_values.update(names_and_values_to_update)
+del names_and_values_to_update, names_and_values, v, k, fn_type
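The tail of this new module rewrites its own globals() so that every function whose annotations mention torch.Tensor is wrapped by hint_on_error, making a missing compiled extension fail with an actionable message. A toy standalone sketch of that decorate-at-module-scope pattern (all names below are illustrative, not from the diff):

    import functools

    def hint_on_error(fn):
        # Wrap a custom-op entry point so a missing kernel fails with a readable hint.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except NotImplementedError as e:
                raise NotImplementedError(
                    f"custom op {fn.__name__} failed: {e} "
                    "(the kernel is probably not built for this device)"
                ) from e
        return wrapper

    def fake_all_reduce(x: "torch.Tensor") -> None:  # toy stand-in for a compiled op
        raise NotImplementedError("no compiled extension")

    # Module-scope rewrite, mirroring the loop at the bottom of _custom_ops.py:
    # wrap every function whose annotations mention torch.Tensor.
    _globals = globals()
    _to_update = {
        name: hint_on_error(value)
        for name, value in list(_globals.items())
        if callable(value)
        and "torch.Tensor" in [str(a) for a in getattr(value, "__annotations__", {}).values()]
    }
    _globals.update(_to_update)

    try:
        fake_all_reduce(None)
    except NotImplementedError as e:
        print(e)  # prints the augmented hint instead of the bare error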
sglang-0.4.0.post1/sglang/srt/configs/device_config.py (new file)

@@ -0,0 +1,17 @@
+import logging
+from typing import Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceConfig:
+    device: Optional[torch.device]
+
+    def __init__(self, device: str = "cuda") -> None:
+        if device in ["cuda", "xpu", "hpu"]:
+            self.device_type = device
+        else:
+            raise RuntimeError(f"Not supported device type: {device}")
+        self.device = torch.device(self.device_type)
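A quick illustration of the new DeviceConfig behavior, assuming sglang 0.4.0.post1 is installed (constructing a torch.device does not require a GPU):

    from sglang.srt.configs.device_config import DeviceConfig

    dev = DeviceConfig("cuda")
    print(dev.device_type)  # "cuda"
    print(dev.device)       # device(type='cuda')

    # Anything outside {"cuda", "xpu", "hpu"} is rejected:
    # DeviceConfig("cpu") raises RuntimeError("Not supported device type: cpu")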
sglang-0.4.0.post1/sglang/srt/configs/load_config.py (new file)

@@ -0,0 +1,84 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+import enum
+import json
+import logging
+from dataclasses import dataclass, field
+from typing import List, Optional, Union
+
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
+
+class LoadFormat(str, enum.Enum):
+    AUTO = "auto"
+    PT = "pt"
+    SAFETENSORS = "safetensors"
+    NPCACHE = "npcache"
+    DUMMY = "dummy"
+    SHARDED_STATE = "sharded_state"
+    GGUF = "gguf"
+    BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"
+
+
+@dataclass
+class LoadConfig:
+    """
+    download_dir: Directory to download and load the weights, default to the
+        default cache directory of huggingface.
+    load_format: The format of the model weights to load:
+        "auto" will try to load the weights in the safetensors format and
+            fall back to the pytorch bin format if safetensors format is
+            not available.
+        "pt" will load the weights in the pytorch bin format.
+        "safetensors" will load the weights in the safetensors format.
+        "npcache" will load the weights in pytorch format and store
+            a numpy cache to speed up the loading.
+        "dummy" will initialize the weights with random values, which is
+            mainly for profiling.
+        "bitsandbytes" will load nf4 type weights.
+    ignore_patterns: The list of patterns to ignore when loading the model.
+        Default to "original/**/*" to avoid repeated loading of llama's
+        checkpoints.
+
+    """
+
+    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
+    download_dir: Optional[str] = None
+    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
+    ignore_patterns: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        model_loader_extra_config = self.model_loader_extra_config or {}
+        if isinstance(model_loader_extra_config, str):
+            self.model_loader_extra_config = json.loads(model_loader_extra_config)
+        self._verify_load_format()
+
+        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+            logger.info(
+                "Ignoring the following patterns when downloading weights: %s",
+                self.ignore_patterns,
+            )
+        else:
+            self.ignore_patterns = ["original/**/*"]
+
+    def _verify_load_format(self) -> None:
+        if not isinstance(self.load_format, str):
+            return
+
+        load_format = self.load_format.lower()
+        self.load_format = LoadFormat(load_format)
+
+        rocm_not_supported_load_format: List[str] = []
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f
+                for f in LoadFormat.__members__
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format '{load_format}' is not supported in ROCm. "
+                f"Supported load formats are "
+                f"{rocm_supported_load_format}"
+            )
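A minimal sketch of how the new LoadConfig normalizes its fields, assuming sglang 0.4.0.post1 is installed; the extra-config key below is a made-up placeholder:

    from sglang.srt.configs.load_config import LoadConfig, LoadFormat

    # __post_init__ parses a JSON string for model_loader_extra_config, coerces
    # the load_format string to the LoadFormat enum, and defaults ignore_patterns.
    cfg = LoadConfig(load_format="safetensors", model_loader_extra_config='{"some_key": 1}')
    print(cfg.load_format)                 # LoadFormat.SAFETENSORS
    print(cfg.model_loader_extra_config)   # {'some_key': 1}
    print(cfg.ignore_patterns)             # ['original/**/*']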
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/configs/model_config.py

@@ -15,12 +15,14 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional
+from typing import List, Optional, Union
 
+import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
 
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
 class ModelConfig:
     def __init__(
         self,
-
+        model_path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
         is_embedding: Optional[bool] = None,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
     ) -> None:
+        self.model_path = model_path
+        self.revision = revision
+        self.quantization = quantization
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-
+            model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        self._verify_quantization()
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ class ModelConfig:
         # parallel size so each GPU has at least one KV head.
         return max(1, total_num_kv_heads // tensor_parallel_size)
 
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
+        return quant_cfg
+
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = [
+            "awq",
+            "gptq",
+            "fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "fbgemm_fp8",
+        ]
+        optimized_quantization_methods = [
+            "fp8",
+            "marlin",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "experts_int8",
+        ]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = self._parse_quant_hf_config()
+
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Detect which checkpoint is it
+            for _, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization
+                )
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization})."
+                )
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}."
+                )
+            if is_hip() and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in ROCm."
+                )
+            if self.quantization not in optimized_quantization_methods:
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.",
+                    self.quantization,
+                )
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
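The added _verify_quantization resolves the effective quantization method from two sources: the quantization argument passed to ModelConfig and the quant_method recorded in the checkpoint's HF config. A standalone sketch of that precedence rule, restated on plain dicts rather than the sglang API (function name and values are illustrative):

    def resolve_quantization(cli_quant, hf_quant_cfg):
        # Mirror of the precedence in _verify_quantization: the checkpoint's
        # quant_method wins when no CLI value is given; a conflicting CLI value raises.
        if hf_quant_cfg is not None:
            quant_method = hf_quant_cfg.get("quant_method", "").lower()
            if cli_quant is None:
                return quant_method
            if cli_quant.lower() != quant_method:
                raise ValueError("quantization argument does not match the model config")
        return cli_quant

    print(resolve_quantization(None, {"quant_method": "fp8"}))   # -> "fp8"
    print(resolve_quantization("fp8", {"quant_method": "fp8"}))  # -> "fp8"
    # resolve_quantization("awq", {"quant_method": "gptq"})      # would raise ValueError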
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
     if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
         # We support non-hf version of llava models, so we do not want to
         # read the wrong values from the unused default text_config.
+        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
+        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
+        setattr(config, "torch_dtype", torch.float16)
         return config
 
     if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16."
+                    )
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
+
+
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architectue
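An illustration of the dtype resolution rules added above. _get_and_verify_dtype is a private helper, so importing it directly is for demonstration only and assumes sglang 0.4.0.post1 with its srt dependencies installed:

    import torch
    from transformers import PretrainedConfig
    from sglang.srt.configs.model_config import _get_and_verify_dtype

    cfg = PretrainedConfig()
    cfg.torch_dtype = torch.float32
    cfg.model_type = "llama"  # any non-gemma2 model type

    print(_get_and_verify_dtype(cfg, "auto"))      # float32 checkpoints default to torch.float16
    print(_get_and_verify_dtype(cfg, "bfloat16"))  # explicit strings map via _STR_DTYPE_TO_TORCH_DTYPE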
{sglang-0.3.6.post3 → sglang-0.4.0.post1}/sglang/srt/configs/qwen2vl.py

@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_scaling = rope_scaling
 
-        # NOTE: the
-
-
-
-
-        # self.rope_scaling["type"] = "default"
-        # self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
+        # NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
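The net effect of this replacement is to rewrite the HF rope_scaling entry in place so downstream rotary-embedding code sees a plain rope_type. A plain-dict illustration of the same rewrite (the mrope_section values are typical Qwen2-VL values, used here only as an example):

    rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

    # Same rewrite as the added lines in Qwen2VLConfig.__init__:
    if rope_scaling is not None and "type" in rope_scaling:
        if rope_scaling["type"] == "mrope":
            rope_scaling["type"] = "default"
        rope_scaling["rope_type"] = rope_scaling["type"]

    print(rope_scaling)
    # {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}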
|