sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/grpc/sglang_scheduler_pb2.py
@@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__
  from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2


- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 
\x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x95\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x08 \x01(\r\"\x9b\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x0b \x01(\rB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 
\x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xd0\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\t\n\x01n\x18\x11 \x01(\x05\x12\x16\n\x0emin_new_tokens\x18\x12 \x01(\x05\x12\x12\n\nignore_eos\x18\x13 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x14 \x01(\x08\x12\x1c\n\x0fstream_interval\x18\x15 \x01(\x05H\x02\x88\x01\x01\x12H\n\nlogit_bias\x18\x16 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x17 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokensB\x12\n\x10_stream_interval\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 
\x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x95\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x08 \x01(\r\"\x9b\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x0b \x01(\rB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 
\x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x15\n\x13GetModelInfoRequest\"\xea\x02\n\x14GetModelInfoResponse\x12\x12\n\nmodel_path\x18\x01 \x01(\t\x12\x16\n\x0etokenizer_path\x18\x02 \x01(\t\x12\x15\n\ris_generation\x18\x03 \x01(\x08\x12!\n\x19preferred_sampling_params\x18\x04 \x01(\t\x12\x16\n\x0eweight_version\x18\x05 \x01(\t\x12\x19\n\x11served_model_name\x18\x06 \x01(\t\x12\x1a\n\x12max_context_length\x18\x07 \x01(\x05\x12\x12\n\nvocab_size\x18\x08 \x01(\x05\x12\x17\n\x0fsupports_vision\x18\t \x01(\x08\x12\x12\n\nmodel_type\x18\n \x01(\t\x12\x15\n\reos_token_ids\x18\x0b \x03(\x05\x12\x14\n\x0cpad_token_id\x18\x0c \x01(\x05\x12\x14\n\x0c\x62os_token_id\x18\r \x01(\x05\x12\x19\n\x11max_req_input_len\x18\x0e \x01(\x05\"\x16\n\x14GetServerInfoRequest\"\xb7\x02\n\x15GetServerInfoResponse\x12,\n\x0bserver_args\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12/\n\x0escheduler_info\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x17\n\x0f\x61\x63tive_requests\x18\x03 \x01(\x05\x12\x11\n\tis_paused\x18\x04 \x01(\x08\x12\x1e\n\x16last_receive_timestamp\x18\x05 \x01(\x01\x12\x16\n\x0euptime_seconds\x18\x06 \x01(\x01\x12\x16\n\x0esglang_version\x18\x07 \x01(\t\x12\x13\n\x0bserver_type\x18\x08 \x01(\t\x12.\n\nstart_time\x18\t \x01(\x0b\x32\x1a.google.protobuf.Timestamp2\xd3\x04\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponse\x12g\n\x0cGetModelInfo\x12*.sglang.grpc.scheduler.GetModelInfoRequest\x1a+.sglang.grpc.scheduler.GetModelInfoResponse\x12j\n\rGetServerInfo\x12+.sglang.grpc.scheduler.GetServerInfoRequest\x1a,.sglang.grpc.scheduler.GetServerInfoResponseb\x06proto3')

  _globals = globals()
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -39,73 +39,81 @@ if not _descriptor._USE_C_DESCRIPTORS:
  _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None
  _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001'
  _globals['_SAMPLINGPARAMS']._serialized_start=113
- _globals['_SAMPLINGPARAMS']._serialized_end=850
- _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=769
- _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=817
- _globals['_DISAGGREGATEDPARAMS']._serialized_start=852
- _globals['_DISAGGREGATEDPARAMS']._serialized_end=945
- _globals['_GENERATEREQUEST']._serialized_start=948
- _globals['_GENERATEREQUEST']._serialized_end=1558
- _globals['_TOKENIZEDINPUT']._serialized_start=1560
- _globals['_TOKENIZEDINPUT']._serialized_end=1618
- _globals['_MULTIMODALINPUTS']._serialized_start=1621
- _globals['_MULTIMODALINPUTS']._serialized_end=1832
- _globals['_GENERATERESPONSE']._serialized_start=1835
- _globals['_GENERATERESPONSE']._serialized_end=2062
- _globals['_GENERATESTREAMCHUNK']._serialized_start=2065
- _globals['_GENERATESTREAMCHUNK']._serialized_end=2342
- _globals['_GENERATECOMPLETE']._serialized_start=2345
- _globals['_GENERATECOMPLETE']._serialized_end=2756
- _globals['_GENERATEERROR']._serialized_start=2758
- _globals['_GENERATEERROR']._serialized_end=2833
- _globals['_OUTPUTLOGPROBS']._serialized_start=2835
- _globals['_OUTPUTLOGPROBS']._serialized_end=2952
- _globals['_INPUTLOGPROBS']._serialized_start=2955
- _globals['_INPUTLOGPROBS']._serialized_end=3113
- _globals['_INPUTTOKENLOGPROB']._serialized_start=3115
- _globals['_INPUTTOKENLOGPROB']._serialized_end=3164
- _globals['_TOPLOGPROBS']._serialized_start=3166
- _globals['_TOPLOGPROBS']._serialized_end=3214
- _globals['_HIDDENSTATES']._serialized_start=3216
- _globals['_HIDDENSTATES']._serialized_end=3279
- _globals['_EMBEDREQUEST']._serialized_start=3282
- _globals['_EMBEDREQUEST']._serialized_end=3612
- _globals['_EMBEDRESPONSE']._serialized_start=3615
- _globals['_EMBEDRESPONSE']._serialized_end=3772
- _globals['_EMBEDCOMPLETE']._serialized_start=3775
- _globals['_EMBEDCOMPLETE']._serialized_end=3938
- _globals['_EMBEDDING']._serialized_start=3940
- _globals['_EMBEDDING']._serialized_end=3982
- _globals['_EMBEDERROR']._serialized_start=3984
- _globals['_EMBEDERROR']._serialized_end=4044
- _globals['_HEALTHCHECKREQUEST']._serialized_start=4046
- _globals['_HEALTHCHECKREQUEST']._serialized_end=4124
- _globals['_HEALTHCHECKRESPONSE']._serialized_start=4126
- _globals['_HEALTHCHECKRESPONSE']._serialized_end=4181
- _globals['_ABORTREQUEST']._serialized_start=4183
- _globals['_ABORTREQUEST']._serialized_end=4233
- _globals['_ABORTRESPONSE']._serialized_start=4235
- _globals['_ABORTRESPONSE']._serialized_end=4284
- _globals['_LOADLORAREQUEST']._serialized_start=4286
- _globals['_LOADLORAREQUEST']._serialized_end=4359
- _globals['_LOADLORARESPONSE']._serialized_start=4361
- _globals['_LOADLORARESPONSE']._serialized_end=4433
- _globals['_UNLOADLORAREQUEST']._serialized_start=4435
- _globals['_UNLOADLORAREQUEST']._serialized_end=4474
- _globals['_UNLOADLORARESPONSE']._serialized_start=4476
- _globals['_UNLOADLORARESPONSE']._serialized_end=4530
- _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4532
- _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4651
- _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4653
- _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4710
- _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4712
- _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4757
- _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4759
- _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4825
- _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4827
- _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4892
- _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4894
- _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4954
- _globals['_SGLANGSCHEDULER']._serialized_start=4957
- _globals['_SGLANGSCHEDULER']._serialized_end=5339
+ _globals['_SAMPLINGPARAMS']._serialized_end=833
+ _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=732
+ _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=780
+ _globals['_DISAGGREGATEDPARAMS']._serialized_start=835
+ _globals['_DISAGGREGATEDPARAMS']._serialized_end=928
+ _globals['_GENERATEREQUEST']._serialized_start=931
+ _globals['_GENERATEREQUEST']._serialized_end=1541
+ _globals['_TOKENIZEDINPUT']._serialized_start=1543
+ _globals['_TOKENIZEDINPUT']._serialized_end=1601
+ _globals['_MULTIMODALINPUTS']._serialized_start=1604
+ _globals['_MULTIMODALINPUTS']._serialized_end=1815
+ _globals['_GENERATERESPONSE']._serialized_start=1818
+ _globals['_GENERATERESPONSE']._serialized_end=2045
+ _globals['_GENERATESTREAMCHUNK']._serialized_start=2048
+ _globals['_GENERATESTREAMCHUNK']._serialized_end=2325
+ _globals['_GENERATECOMPLETE']._serialized_start=2328
+ _globals['_GENERATECOMPLETE']._serialized_end=2739
+ _globals['_GENERATEERROR']._serialized_start=2741
+ _globals['_GENERATEERROR']._serialized_end=2816
+ _globals['_OUTPUTLOGPROBS']._serialized_start=2818
+ _globals['_OUTPUTLOGPROBS']._serialized_end=2935
+ _globals['_INPUTLOGPROBS']._serialized_start=2938
+ _globals['_INPUTLOGPROBS']._serialized_end=3096
+ _globals['_INPUTTOKENLOGPROB']._serialized_start=3098
+ _globals['_INPUTTOKENLOGPROB']._serialized_end=3147
+ _globals['_TOPLOGPROBS']._serialized_start=3149
+ _globals['_TOPLOGPROBS']._serialized_end=3197
+ _globals['_HIDDENSTATES']._serialized_start=3199
+ _globals['_HIDDENSTATES']._serialized_end=3262
+ _globals['_EMBEDREQUEST']._serialized_start=3265
+ _globals['_EMBEDREQUEST']._serialized_end=3595
+ _globals['_EMBEDRESPONSE']._serialized_start=3598
+ _globals['_EMBEDRESPONSE']._serialized_end=3755
+ _globals['_EMBEDCOMPLETE']._serialized_start=3758
+ _globals['_EMBEDCOMPLETE']._serialized_end=3921
+ _globals['_EMBEDDING']._serialized_start=3923
+ _globals['_EMBEDDING']._serialized_end=3965
+ _globals['_EMBEDERROR']._serialized_start=3967
+ _globals['_EMBEDERROR']._serialized_end=4027
+ _globals['_HEALTHCHECKREQUEST']._serialized_start=4029
+ _globals['_HEALTHCHECKREQUEST']._serialized_end=4107
+ _globals['_HEALTHCHECKRESPONSE']._serialized_start=4109
+ _globals['_HEALTHCHECKRESPONSE']._serialized_end=4164
+ _globals['_ABORTREQUEST']._serialized_start=4166
+ _globals['_ABORTREQUEST']._serialized_end=4216
+ _globals['_ABORTRESPONSE']._serialized_start=4218
+ _globals['_ABORTRESPONSE']._serialized_end=4267
+ _globals['_LOADLORAREQUEST']._serialized_start=4269
+ _globals['_LOADLORAREQUEST']._serialized_end=4342
+ _globals['_LOADLORARESPONSE']._serialized_start=4344
+ _globals['_LOADLORARESPONSE']._serialized_end=4416
+ _globals['_UNLOADLORAREQUEST']._serialized_start=4418
+ _globals['_UNLOADLORAREQUEST']._serialized_end=4457
+ _globals['_UNLOADLORARESPONSE']._serialized_start=4459
+ _globals['_UNLOADLORARESPONSE']._serialized_end=4513
+ _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4515
+ _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4634
+ _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4636
+ _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4693
+ _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4695
+ _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4740
+ _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4742
+ _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4808
+ _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4810
+ _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4875
+ _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4877
+ _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4937
+ _globals['_GETMODELINFOREQUEST']._serialized_start=4939
+ _globals['_GETMODELINFOREQUEST']._serialized_end=4960
+ _globals['_GETMODELINFORESPONSE']._serialized_start=4963
+ _globals['_GETMODELINFORESPONSE']._serialized_end=5325
+ _globals['_GETSERVERINFOREQUEST']._serialized_start=5327
+ _globals['_GETSERVERINFOREQUEST']._serialized_end=5349
+ _globals['_GETSERVERINFORESPONSE']._serialized_start=5352
+ _globals['_GETSERVERINFORESPONSE']._serialized_end=5663
+ _globals['_SGLANGSCHEDULER']._serialized_start=5666
+ _globals['_SGLANGSCHEDULER']._serialized_end=6261
  # @@protoc_insertion_point(module_scope)
sglang/srt/grpc/sglang_scheduler_pb2.pyi
@@ -11,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
  DESCRIPTOR: _descriptor.FileDescriptor

  class SamplingParams(_message.Message):
- __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
+ __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "n", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
  class LogitBiasEntry(_message.Message):
  __slots__ = ("key", "value")
  KEY_FIELD_NUMBER: _ClassVar[int]
@@ -35,9 +35,7 @@ class SamplingParams(_message.Message):
  JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
  EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
  STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
- LORA_PATH_FIELD_NUMBER: _ClassVar[int]
  N_FIELD_NUMBER: _ClassVar[int]
- TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
  MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
  IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
  NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
@@ -60,16 +58,14 @@ class SamplingParams(_message.Message):
  json_schema: str
  ebnf_grammar: str
  structural_tag: str
- lora_path: str
  n: int
- token_healing: bool
  min_new_tokens: int
  ignore_eos: bool
  no_stop_trim: bool
  stream_interval: int
  logit_bias: _containers.ScalarMap[str, float]
  custom_params: _struct_pb2.Struct
- def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
+ def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., n: _Optional[int] = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
 
  class DisaggregatedParams(_message.Message):
  __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
@@ -432,3 +428,65 @@ class SetInternalStateResponse(_message.Message):
  success: bool
  message: str
  def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
+
+ class GetModelInfoRequest(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+ class GetModelInfoResponse(_message.Message):
+ __slots__ = ("model_path", "tokenizer_path", "is_generation", "preferred_sampling_params", "weight_version", "served_model_name", "max_context_length", "vocab_size", "supports_vision", "model_type", "eos_token_ids", "pad_token_id", "bos_token_id", "max_req_input_len")
+ MODEL_PATH_FIELD_NUMBER: _ClassVar[int]
+ TOKENIZER_PATH_FIELD_NUMBER: _ClassVar[int]
+ IS_GENERATION_FIELD_NUMBER: _ClassVar[int]
+ PREFERRED_SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
+ WEIGHT_VERSION_FIELD_NUMBER: _ClassVar[int]
+ SERVED_MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ MAX_CONTEXT_LENGTH_FIELD_NUMBER: _ClassVar[int]
+ VOCAB_SIZE_FIELD_NUMBER: _ClassVar[int]
+ SUPPORTS_VISION_FIELD_NUMBER: _ClassVar[int]
+ MODEL_TYPE_FIELD_NUMBER: _ClassVar[int]
+ EOS_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+ PAD_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+ BOS_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+ MAX_REQ_INPUT_LEN_FIELD_NUMBER: _ClassVar[int]
+ model_path: str
+ tokenizer_path: str
+ is_generation: bool
+ preferred_sampling_params: str
+ weight_version: str
+ served_model_name: str
+ max_context_length: int
+ vocab_size: int
+ supports_vision: bool
+ model_type: str
+ eos_token_ids: _containers.RepeatedScalarFieldContainer[int]
+ pad_token_id: int
+ bos_token_id: int
+ max_req_input_len: int
+ def __init__(self, model_path: _Optional[str] = ..., tokenizer_path: _Optional[str] = ..., is_generation: bool = ..., preferred_sampling_params: _Optional[str] = ..., weight_version: _Optional[str] = ..., served_model_name: _Optional[str] = ..., max_context_length: _Optional[int] = ..., vocab_size: _Optional[int] = ..., supports_vision: bool = ..., model_type: _Optional[str] = ..., eos_token_ids: _Optional[_Iterable[int]] = ..., pad_token_id: _Optional[int] = ..., bos_token_id: _Optional[int] = ..., max_req_input_len: _Optional[int] = ...) -> None: ...
+
+ class GetServerInfoRequest(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+ class GetServerInfoResponse(_message.Message):
+ __slots__ = ("server_args", "scheduler_info", "active_requests", "is_paused", "last_receive_timestamp", "uptime_seconds", "sglang_version", "server_type", "start_time")
+ SERVER_ARGS_FIELD_NUMBER: _ClassVar[int]
+ SCHEDULER_INFO_FIELD_NUMBER: _ClassVar[int]
+ ACTIVE_REQUESTS_FIELD_NUMBER: _ClassVar[int]
+ IS_PAUSED_FIELD_NUMBER: _ClassVar[int]
+ LAST_RECEIVE_TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
+ UPTIME_SECONDS_FIELD_NUMBER: _ClassVar[int]
+ SGLANG_VERSION_FIELD_NUMBER: _ClassVar[int]
+ SERVER_TYPE_FIELD_NUMBER: _ClassVar[int]
+ START_TIME_FIELD_NUMBER: _ClassVar[int]
+ server_args: _struct_pb2.Struct
+ scheduler_info: _struct_pb2.Struct
+ active_requests: int
+ is_paused: bool
+ last_receive_timestamp: float
+ uptime_seconds: float
+ sglang_version: str
+ server_type: str
+ start_time: _timestamp_pb2.Timestamp
+ def __init__(self, server_args: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., scheduler_info: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., active_requests: _Optional[int] = ..., is_paused: bool = ..., last_receive_timestamp: _Optional[float] = ..., uptime_seconds: _Optional[float] = ..., sglang_version: _Optional[str] = ..., server_type: _Optional[str] = ..., start_time: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ...
sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
@@ -59,6 +59,16 @@ class SglangSchedulerStub(object):
  request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString,
  response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString,
  _registered_method=True)
+ self.GetModelInfo = channel.unary_unary(
+ '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo',
+ request_serializer=sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString,
+ response_deserializer=sglang__scheduler__pb2.GetModelInfoResponse.FromString,
+ _registered_method=True)
+ self.GetServerInfo = channel.unary_unary(
+ '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo',
+ request_serializer=sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString,
+ response_deserializer=sglang__scheduler__pb2.GetServerInfoResponse.FromString,
+ _registered_method=True)


  class SglangSchedulerServicer(object):
@@ -94,6 +104,20 @@ class SglangSchedulerServicer(object):
  context.set_details('Method not implemented!')
  raise NotImplementedError('Method not implemented!')

+ def GetModelInfo(self, request, context):
+ """Get model information
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def GetServerInfo(self, request, context):
+ """Get server information
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+

  def add_SglangSchedulerServicer_to_server(servicer, server):
  rpc_method_handlers = {
@@ -117,6 +141,16 @@ def add_SglangSchedulerServicer_to_server(servicer, server):
  request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString,
  response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString,
  ),
+ 'GetModelInfo': grpc.unary_unary_rpc_method_handler(
+ servicer.GetModelInfo,
+ request_deserializer=sglang__scheduler__pb2.GetModelInfoRequest.FromString,
+ response_serializer=sglang__scheduler__pb2.GetModelInfoResponse.SerializeToString,
+ ),
+ 'GetServerInfo': grpc.unary_unary_rpc_method_handler(
+ servicer.GetServerInfo,
+ request_deserializer=sglang__scheduler__pb2.GetServerInfoRequest.FromString,
+ response_serializer=sglang__scheduler__pb2.GetServerInfoResponse.SerializeToString,
+ ),
  }
  generic_handler = grpc.method_handlers_generic_handler(
  'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
@@ -237,3 +271,57 @@ class SglangScheduler(object):
  timeout,
  metadata,
  _registered_method=True)
+
+ @staticmethod
+ def GetModelInfo(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/sglang.grpc.scheduler.SglangScheduler/GetModelInfo',
+ sglang__scheduler__pb2.GetModelInfoRequest.SerializeToString,
+ sglang__scheduler__pb2.GetModelInfoResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
+
+ @staticmethod
+ def GetServerInfo(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/sglang.grpc.scheduler.SglangScheduler/GetServerInfo',
+ sglang__scheduler__pb2.GetServerInfoRequest.SerializeToString,
+ sglang__scheduler__pb2.GetServerInfoResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True)
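The two RPCs added above, GetModelInfo and GetServerInfo, are plain unary calls on the existing SglangScheduler service, and both request messages are empty. A minimal client sketch, assuming the generated modules are importable as shown and that a scheduler is listening on localhost:30000 (the target address and import paths are illustrative assumptions, not taken from this diff):

import grpc

from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc

def query_scheduler(target: str = "localhost:30000") -> None:
    # Open an insecure channel and create the generated stub.
    with grpc.insecure_channel(target) as channel:
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)
        # Both request messages carry no fields per the new proto definitions.
        model_info = stub.GetModelInfo(sglang_scheduler_pb2.GetModelInfoRequest())
        server_info = stub.GetServerInfo(sglang_scheduler_pb2.GetServerInfoRequest())
        # Fields below come from the new GetModelInfoResponse / GetServerInfoResponse messages.
        print(model_info.model_path, model_info.max_context_length)
        print(server_info.sglang_version, server_info.uptime_seconds)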
sglang/srt/layers/attention/attention_registry.py
@@ -1,7 +1,14 @@
  import logging
+ from typing import TYPE_CHECKING

  logger = logging.getLogger(__name__)

+
+ if TYPE_CHECKING:
+ # evade circular imports
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.model_executor.model_runner import ModelRunner
+
  ATTENTION_BACKENDS = {}

 
@@ -129,9 +136,6 @@ def create_flashattention_v3_backend(runner):

  @register_attention_backend("fa4")
  def create_flashattention_v4_backend(runner):
- assert (
- runner.use_mla_backend
- ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
  from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

  return FlashAttentionBackend(runner, fa_impl_ver=4)
@@ -169,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
  return DualChunkFlashAttentionBackend(runner)


- def attn_backend_wrapper(runner, full_attn_backend):
+ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
  """
  Wrapper for special models like hybrid GDN, so we don't
  need to change the code of the original attention backend.
  """
  assert not (
- runner.is_hybrid_gdn and runner.use_mla_backend
+ runner.hybrid_gdn_config is not None and runner.use_mla_backend
  ), "hybrid_gdn can only be used with non-MLA models."

- # wrap for hybrid GDN models
- if runner.is_hybrid_gdn:
- from sglang.srt.utils import is_blackwell, is_npu
-
- if is_blackwell():
- assert (
- runner.server_args.attention_backend == "triton"
- or runner.server_args.attention_backend == "trtllm_mha"
- ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
- if is_npu():
- assert (
- runner.server_args.attention_backend == "ascend"
- ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
- logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+ if cfg := runner.mambaish_config:
  from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+ GDNAttnBackend,
  HybridLinearAttnBackend,
- MambaAttnBackend,
+ Mamba2AttnBackend,
  )
+ from sglang.srt.utils import is_blackwell, is_npu

- linear_attn_backend = MambaAttnBackend(runner)
- full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+ if runner.hybrid_gdn_config is not None:
+ if is_blackwell():
+ assert (
+ runner.server_args.attention_backend == "triton"
+ ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
+ if is_npu():
+ assert (
+ runner.server_args.attention_backend == "ascend"
+ ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+ logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+ linear_attn_backend = GDNAttnBackend(runner)
+ elif runner.mamba2_config is not None:
+ linear_attn_backend = Mamba2AttnBackend(runner)
+ else:
+ raise ValueError(
+ "Expected hybrid GDN or NemotronH models, but got unknown model."
+ )
+ full_attn_layers = cfg.full_attention_layer_ids
  return HybridLinearAttnBackend(
  full_attn_backend, linear_attn_backend, full_attn_layers
  )
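For context, the @register_attention_backend(...) decorator used above fills the module-level ATTENTION_BACKENDS dict with name-to-factory entries, and attn_backend_wrapper then optionally wraps the chosen backend for Mamba-style hybrid models. A minimal sketch of that registration pattern, assuming the decorator simply stores the factory under its name (the decorator body and the toy backend here are illustrative assumptions; only the names ATTENTION_BACKENDS and register_attention_backend come from the file):

ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    # Assumed shape of the decorator: record the factory under its backend name.
    def decorator(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory
    return decorator

@register_attention_backend("toy")
def create_toy_backend(runner):
    # A real factory (e.g. create_flashattention_v4_backend) returns an
    # AttentionBackend built from the ModelRunner; this stand-in just echoes its input.
    return ("toy-backend", runner)

backend = ATTENTION_BACKENDS["toy"](runner=None)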
sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -181,6 +181,45 @@ def _layer_norm_fwd(
  return out, mean, rstd


+ def rms_norm_gated(
+ *,
+ x,
+ weight,
+ bias,
+ z=None,
+ eps=1e-6,
+ group_size=None,
+ norm_before_gate=True,
+ is_rms_norm=False,
+ ):
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+
+ x_shape_og = x.shape
+ # reshape input data into 2D tensor
+ x = x.reshape(-1, x.shape[-1])
+ if x.stride(-1) != 1:
+ x = x.contiguous()
+ if z is not None:
+ assert z.shape == x_shape_og
+ z = z.reshape(-1, z.shape[-1])
+ if z.stride(-1) != 1:
+ z = z.contiguous()
+ weight = weight.contiguous()
+ if bias is not None:
+ bias = bias.contiguous()
+ y, mean, rstd = _layer_norm_fwd(
+ x,
+ weight,
+ bias,
+ eps,
+ z=z,
+ group_size=group_size,
+ norm_before_gate=norm_before_gate,
+ is_rms_norm=is_rms_norm,
+ )
+ return y.reshape(x_shape_og)
+
+
  class LayerNormFn(torch.autograd.Function):

  @staticmethod
@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
  norm_before_gate=True,
  is_rms_norm=False,
  ):
- """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-
- x_shape_og = x.shape
- # reshape input data into 2D tensor
- x = x.reshape(-1, x.shape[-1])
- if x.stride(-1) != 1:
- x = x.contiguous()
- if z is not None:
- assert z.shape == x_shape_og
- z = z.reshape(-1, z.shape[-1])
- if z.stride(-1) != 1:
- z = z.contiguous()
- weight = weight.contiguous()
- if bias is not None:
- bias = bias.contiguous()
- y, mean, rstd = _layer_norm_fwd(
- x,
- weight,
- bias,
- eps,
+ return rms_norm_gated(
+ x=x,
+ weight=weight,
+ bias=bias,
+ eps=eps,
  z=z,
  group_size=group_size,
  norm_before_gate=norm_before_gate,
  is_rms_norm=is_rms_norm,
  )
- return y.reshape(x_shape_og)


  def layernorm_fn(
@@ -238,14 +261,6 @@ def layernorm_fn(
  )


- def rmsnorm_fn(
- x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
- ):
- return LayerNormFn.apply(
- x, weight, bias, z, eps, group_size, norm_before_gate, True
- )
-
-
  class LayerNorm(torch.nn.Module):

  def __init__(
@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
  group_size=self.group_size,
  eps=self.eps,
  norm_before_gate=self.norm_before_gate,
+ is_rms_norm=False,
  )

 
@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):

  def forward(self, x, z=None):
  """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
- return rmsnorm_fn(
+ return layernorm_fn(
  x,
  self.weight,
  self.bias,
@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
  eps=self.eps,
  group_size=self.group_size,
  norm_before_gate=self.norm_before_gate,
+ is_rms_norm=True,
  )
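The refactor above extracts the gated-norm forward path into a reusable rms_norm_gated() helper and routes RMSNorm.forward through layernorm_fn with an explicit is_rms_norm flag, dropping the separate rmsnorm_fn wrapper. A hedged usage sketch of the new helper (tensor shapes, dtype, and the CUDA device are illustrative assumptions; the keyword-only signature matches the diff, and the underlying Triton kernel needs a GPU):

import torch

from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated

hidden_size = 128
x = torch.randn(4, 64, hidden_size, device="cuda", dtype=torch.float16)
z = torch.randn_like(x)  # gate input: output is rms_norm(x) * silu(z) when norm_before_gate=True
weight = torch.ones(hidden_size, device="cuda", dtype=torch.float16)

y = rms_norm_gated(
    x=x,
    weight=weight,
    bias=None,
    z=z,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=True,
)
assert y.shape == x.shape  # output keeps the original input shape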
sglang/srt/layers/attention/flashattention_backend.py
@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):

  # Use Flash Attention for prefill
  if not self.use_mla:
- assert self.fa_impl_ver in [3], "Only FA3 support here"
  # Do multi-head attention
  key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
  layer.layer_id