sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py}
CHANGED
@@ -38,8 +38,10 @@ from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
     DeepseekVL2Config,
+    DotsOCRConfig,
     DotsVLMConfig,
     ExaoneConfig,
+    FalconH1Config,
     KimiVLConfig,
     LongcatFlashConfig,
     MultiModalityConfig,
@@ -61,7 +63,9 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     Step3VLConfig.model_type: Step3VLConfig,
     LongcatFlashConfig.model_type: LongcatFlashConfig,
     Qwen3NextConfig.model_type: Qwen3NextConfig,
+    FalconH1Config.model_type: FalconH1Config,
     DotsVLMConfig.model_type: DotsVLMConfig,
+    DotsOCRConfig.model_type: DotsOCRConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
@@ -119,6 +123,38 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+# Temporary hack for DeepSeek-V3.2 model
+def _load_deepseek_v32_model(
+    model_path: str,
+    trust_remote_code: bool = False,
+    revision: Optional[str] = None,
+    **kwargs,
+):
+    # first get the local path
+    local_path = download_from_hf(model_path)
+    # then load the config file in json
+    config_file = os.path.join(local_path, "config.json")
+    if not os.path.exists(config_file):
+        raise RuntimeError(f"Can't find config file in {local_path}.")
+
+    with open(config_file, "r") as f:
+        config_json = json.load(f)
+
+    config_json["architectures"] = ["DeepseekV3ForCausalLM"]
+    config_json["model_type"] = "deepseek_v3"
+
+    tmp_path = os.path.join(local_path, "_tmp_config_folder")
+    os.makedirs(tmp_path, exist_ok=True)
+
+    unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
+    with open(unique_path, "w") as f:
+        json.dump(config_json, f)
+
+    return AutoConfig.from_pretrained(
+        unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+    )
+
+
 @lru_cache_frozenset(maxsize=32)
 def get_config(
     model: str,
@@ -140,9 +176,17 @@ def get_config(
         client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
         model = client.get_local_dir()
 
-    config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
-    )
+    try:
+        config = AutoConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+    except ValueError as e:
+        if not "deepseek_v32" in str(e):
+            raise e
+        config = _load_deepseek_v32_model(
+            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+        )
+
     if (
         config.architectures is not None
         and config.architectures[0] == "Phi4MMForCausalLM"
@@ -374,8 +418,8 @@ def get_processor(
             **kwargs,
         )
 
-    # fix: for Qwen2-VL
-    if config.model_type in {"qwen2_vl"}:
+    # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
         if "size" not in kwargs:
             kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
 
sglang/srt/{patch_torch.py → utils/patch_torch.py}
CHANGED
@@ -17,10 +17,18 @@ import torch
 from packaging import version
 from torch.multiprocessing import reductions
 
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+
 
 def monkey_patch_torch_reductions():
     """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
 
+    # Currently, NPU does not support UUID. This has been temporarily commented out, with support expected in the fourth quarter.
+    if _is_npu:
+        return
+
     if hasattr(reductions, "_reduce_tensor_original"):
         return
 
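
For reference, the guard added here computes the platform flag once at import time and turns the whole patch into a no-op on NPU. Calling code is unchanged (a minimal sketch; the call site is illustrative):

    from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions

    # On NPU builds is_npu() is True and this returns immediately, leaving
    # torch.multiprocessing.reductions untouched; on other platforms it
    # still installs the patched tensor reducer as before.
    monkey_patch_torch_reductions()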
sglang/srt/utils/rpd_utils.py
ADDED
@@ -0,0 +1,452 @@
+# https://raw.githubusercontent.com/ROCm/rocmProfileData/refs/heads/master/tools/rpd2tracing.py
+# commit 92d13a08328625463e9ba944cece82fc5eea36e6
+def rpd_to_chrome_trace(
+    input_rpd, output_json=None, start="0%", end="100%", format="object"
+):
+    import gzip
+    import sqlite3
+
+    if output_json is None:
+        import pathlib
+
+        output_json = pathlib.PurePath(input_rpd).with_suffix(".trace.json.gz")
+
+    connection = sqlite3.connect(input_rpd)
+
+    outfile = gzip.open(output_json, "wt", encoding="utf-8")
+
+    if format == "object":
+        outfile.write('{"traceEvents": ')
+
+    outfile.write("[ {}\n")
+
+    for row in connection.execute("select distinct gpuId from rocpd_op"):
+        try:
+            outfile.write(
+                ',{"name": "process_name", "ph": "M", "pid":"%s","args":{"name":"%s"}}\n'
+                % (row[0], "GPU" + str(row[0]))
+            )
+            outfile.write(
+                ',{"name": "process_sort_index", "ph": "M", "pid":"%s","args":{"sort_index":"%s"}}\n'
+                % (row[0], row[0] + 1000000)
+            )
+        except ValueError:
+            outfile.write("")
+
+    for row in connection.execute("select distinct pid, tid from rocpd_api"):
+        try:
+            outfile.write(
+                ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
+                % (row[0], row[1], "Hip " + str(row[1]))
+            )
+            outfile.write(
+                ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
+                % (row[0], row[1], row[1] * 2)
+            )
+        except ValueError:
+            outfile.write("")
+
+    try:
+        # FIXME - these aren't rendering correctly in chrome://tracing
+        for row in connection.execute("select distinct pid, tid from rocpd_hsaApi"):
+            try:
+                outfile.write(
+                    ',{"name":"thread_name","ph":"M","pid":"%s","tid":"%s","args":{"name":"%s"}}\n'
+                    % (row[0], row[1], "HSA " + str(row[1]))
+                )
+                outfile.write(
+                    ',{"name":"thread_sort_index","ph":"M","pid":"%s","tid":"%s","args":{"sort_index":"%s"}}\n'
+                    % (row[0], row[1], row[1] * 2 - 1)
+                )
+            except ValueError:
+                outfile.write("")
+    except:
+        pass
+
+    rangeStringApi = ""
+    rangeStringOp = ""
+    rangeStringMonitor = ""
+    min_time = connection.execute("select MIN(start) from rocpd_api;").fetchall()[0][0]
+    max_time = connection.execute("select MAX(end) from rocpd_api;").fetchall()[0][0]
+    if min_time == None:
+        raise Exception("Trace file is empty.")
+
+    print("Timestamps:")
+    print(f"\t first: \t{min_time/1000} us")
+    print(f"\t last: \t{max_time/1000} us")
+    print(f"\t duration: \t{(max_time-min_time) / 1000000000} seconds")
+
+    start_time = min_time / 1000
+    end_time = max_time / 1000
+
+    if start:
+        if "%" in start:
+            start_time = (
+                (max_time - min_time) * (int(start.replace("%", "")) / 100) + min_time
+            ) / 1000
+        else:
+            start_time = int(start)
+        rangeStringApi = "where rocpd_api.start/1000 >= %s" % (start_time)
+        rangeStringOp = "where rocpd_op.start/1000 >= %s" % (start_time)
+        rangeStringMonitor = "where start/1000 >= %s" % (start_time)
+    if end:
+        if "%" in end:
+            end_time = (
+                (max_time - min_time) * (int(end.replace("%", "")) / 100) + min_time
+            ) / 1000
+        else:
+            end_time = int(end)
+
+        rangeStringApi = (
+            rangeStringApi + " and rocpd_api.start/1000 <= %s" % (end_time)
+            if start != None
+            else "where rocpd_api.start/1000 <= %s" % (end_time)
+        )
+        rangeStringOp = (
+            rangeStringOp + " and rocpd_op.start/1000 <= %s" % (end_time)
+            if start != None
+            else "where rocpd_op.start/1000 <= %s" % (end_time)
+        )
+        rangeStringMonitor = (
+            rangeStringMonitor + " and start/1000 <= %s" % (end_time)
+            if start != None
+            else "where start/1000 <= %s" % (end_time)
+        )
+
+    print("\nFilter: %s" % (rangeStringApi))
+    print(f"Output duration: {(end_time-start_time)/1000000} seconds")
+
+    # Output Ops
+
+    for row in connection.execute(
+        "select A.string as optype, B.string as description, gpuId, queueId, rocpd_op.start/1000.0, (rocpd_op.end-rocpd_op.start) / 1000.0 from rocpd_op INNER JOIN rocpd_string A on A.id = rocpd_op.opType_id INNER Join rocpd_string B on B.id = rocpd_op.description_id %s"
+        % (rangeStringOp)
+    ):
+        try:
+            name = row[0] if len(row[1]) == 0 else row[1]
+            outfile.write(
+                ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                % (row[2], row[3], name, row[4], row[5], row[0])
+            )
+        except ValueError:
+            outfile.write("")
+
+    # Output Graph executions on GPU
+    try:
+        for row in connection.execute(
+            "select graphExec, gpuId, queueId, min(start)/1000.0, (max(end)-min(start))/1000.0, count(*) from rocpd_graphLaunchapi A join rocpd_api_ops B on B.api_id = A.api_ptr_id join rocpd_op C on C.id = B.op_id %s group by api_ptr_id"
+            % (rangeStringMonitor)
+        ):
+            try:
+                outfile.write(
+                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"kernels":"%s"}}\n'
+                    % (row[1], row[2], f"Graph {row[0]}", row[3], row[4], row[5])
+                )
+            except ValueError:
+                outfile.write("")
+    except:
+        pass
+
+    # Output apis
+    for row in connection.execute(
+        "select A.string as apiName, B.string as args, pid, tid, rocpd_api.start/1000.0, (rocpd_api.end-rocpd_api.start) / 1000.0, (rocpd_api.end != rocpd_api.start) as has_duration from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id INNER Join rocpd_string B on B.id = rocpd_api.args_id %s order by rocpd_api.id"
+        % (rangeStringApi)
+    ):
+        try:
+            if row[0] == "UserMarker":
+                if row[6] == 0:  # instantanuous "mark" messages
+                    outfile.write(
+                        ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","ph":"i","s":"p","args":{"desc":"%s"}}\n'
+                        % (
+                            row[2],
+                            row[3],
+                            row[1].replace('"', ""),
+                            row[4],
+                            row[1].replace('"', ""),
+                        )
+                    )
+                else:
+                    outfile.write(
+                        ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                        % (
+                            row[2],
+                            row[3],
+                            row[1].replace('"', ""),
+                            row[4],
+                            row[5],
+                            row[1].replace('"', ""),
+                        )
+                    )
+            else:
+                outfile.write(
+                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                    % (
+                        row[2],
+                        row[3],
+                        row[0],
+                        row[4],
+                        row[5],
+                        row[1].replace('"', "").replace("\t", ""),
+                    )
+                )
+        except ValueError:
+            outfile.write("")
+
+    # Output api->op linkage
+    for row in connection.execute(
+        "select rocpd_api_ops.id, pid, tid, gpuId, queueId, rocpd_api.end/1000.0 - 2, rocpd_op.start/1000.0 from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id %s"
+        % (rangeStringApi)
+    ):
+        try:
+            fromtime = row[5] if row[5] < row[6] else row[6]
+            outfile.write(
+                ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"s"}\n'
+                % (row[1], row[2], fromtime, row[0])
+            )
+            outfile.write(
+                ',{"pid":"%s","tid":"%s","cat":"api_op","name":"api_op","ts":"%s","id":"%s","ph":"f", "bp":"e"}\n'
+                % (row[3], row[4], row[6], row[0])
+            )
+        except ValueError:
+            outfile.write("")
+
+    try:
+        for row in connection.execute(
+            "select A.string as apiName, B.string as args, pid, tid, rocpd_hsaApi.start/1000.0, (rocpd_hsaApi.end-rocpd_hsaApi.start) / 1000.0 from rocpd_hsaApi INNER JOIN rocpd_string A on A.id = rocpd_hsaApi.apiName_id INNER Join rocpd_string B on B.id = rocpd_hsaApi.args_id %s order by rocpd_hsaApi.id"
+            % (rangeStringApi)
+        ):
+            try:
+                outfile.write(
+                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                    % (
+                        row[2],
+                        row[3] + 1,
+                        row[0],
+                        row[4],
+                        row[5],
+                        row[1].replace('"', ""),
+                    )
+                )
+            except ValueError:
+                outfile.write("")
+    except:
+        pass
+
+    #
+    # Counters
+    #
+
+    # Counters should extend to the last event in the trace. This means they need to have a value at Tend.
+    # Figure out when that is
+
+    T_end = 0
+    for row in connection.execute(
+        "SELECT max(end)/1000 from (SELECT end from rocpd_api UNION ALL SELECT end from rocpd_op)"
+    ):
+        T_end = int(row[0])
+    if end:
+        T_end = end_time
+
+    # Loop over GPU for per-gpu counters
+    gpuIdsPresent = []
+    for row in connection.execute("SELECT DISTINCT gpuId FROM rocpd_op"):
+        gpuIdsPresent.append(row[0])
+
+    for gpuId in gpuIdsPresent:
+        # print(f"Creating counters for: {gpuId}")
+
+        # Create the queue depth counter
+        depth = 0
+        idle = 1
+        for row in connection.execute(
+            'select * from (select rocpd_api.start/1000.0 as ts, "1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s UNION ALL select rocpd_op.end/1000.0, "-1" from rocpd_api_ops INNER JOIN rocpd_api on rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op on rocpd_api_ops.op_id = rocpd_op.id AND rocpd_op.gpuId = %s %s) order by ts'
+            % (gpuId, rangeStringOp, gpuId, rangeStringOp)
+        ):
+            try:
+                if idle and int(row[1]) > 0:
+                    idle = 0
+                    outfile.write(
+                        ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
+                        % (gpuId, row[0], idle)
+                    )
+                if depth == 1 and int(row[1]) < 0:
+                    idle = 1
+                    outfile.write(
+                        ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
+                        % (gpuId, row[0], idle)
+                    )
+                depth = depth + int(row[1])
+                outfile.write(
+                    ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
+                    % (gpuId, row[0], depth)
+                )
+            except ValueError:
+                outfile.write("")
+        if T_end > 0:
+            outfile.write(
+                ',{"pid":"%s","name":"Idle","ph":"C","ts":%s,"args":{"idle":%s}}\n'
+                % (gpuId, T_end, idle)
+            )
+            outfile.write(
+                ',{"pid":"%s","name":"QueueDepth","ph":"C","ts":%s,"args":{"depth":%s}}\n'
+                % (gpuId, T_end, depth)
+            )
+
+    # Create SMI counters
+    try:
+        for row in connection.execute(
+            "select deviceId, monitorType, start/1000.0, value from rocpd_monitor %s"
+            % (rangeStringMonitor)
+        ):
+            outfile.write(
+                ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
+                % (row[0], row[1], row[2], row[1], row[3])
+            )
+        # Output the endpoints of the last range
+        for row in connection.execute(
+            "select distinct deviceId, monitorType, max(end)/1000.0, value from rocpd_monitor %s group by deviceId, monitorType"
+            % (rangeStringMonitor)
+        ):
+            outfile.write(
+                ',{"pid":"%s","name":"%s","ph":"C","ts":%s,"args":{"%s":%s}}\n'
+                % (row[0], row[1], row[2], row[1], row[3])
+            )
+    except:
+        print("Did not find SMI data")
+
+    # Create the (global) memory counter
+    """
+    sizes = {} # address -> size
+    totalSize = 0
+    exp = re.compile("^ptr\((.*)\)\s+size\((.*)\)$")
+    exp2 = re.compile("^ptr\((.*)\)$")
+    for row in connection.execute("SELECT rocpd_api.end/1000.0 as ts, B.string, '1' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipFree' UNION ALL SELECT rocpd_api.start/1000.0, B.string, '0' FROM rocpd_api INNER JOIN rocpd_string A ON A.id=rocpd_api.apiName_id INNER JOIN rocpd_string B ON B.id=rocpd_api.args_id WHERE A.string='hipMalloc' ORDER BY ts asc"):
+        try:
+            if row[2] == '0': #malloc
+                m = exp.match(row[1])
+                if m:
+                    size = int(m.group(2), 16)
+                    totalSize = totalSize + size
+                    sizes[m.group(1)] = size
+                    outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
+            else: #free
+                m = exp2.match(row[1])
+                if m:
+                    try: # Sometimes free addresses are not valid or listed
+                        size = sizes[m.group(1)]
+                        sizes[m.group(1)] = 0
+                        totalSize = totalSize - size;
+                        outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(row[0],totalSize))
+                    except KeyError:
+                        pass
+        except ValueError:
+            outfile.write("")
+    if T_end > 0:
+        outfile.write(',{"pid":"0","name":"Allocated Memory","ph":"C","ts":%s,"args":{"depth":%s}}\n'%(T_end,totalSize))
+    """
+
+    # Create "faux calling stack frame" on gpu ops traceS
+    stacks = {} # Call stacks built from UserMarker entres. Key is 'pid,tid'
+    currentFrame = {} # "Current GPU frame" (id, name, start, end). Key is 'pid,tid'
+
+    class GpuFrame:
+        def __init__(self):
+            self.id = 0
+            self.name = ""
+            self.start = 0
+            self.end = 0
+            self.gpus = []
+            self.totalOps = 0
+
+    # FIXME: include 'start' (in ns) so we can ORDER BY it and break ties?
+    for row in connection.execute(
+        "SELECT '0', start/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '1', end/1000.0, pid, tid, B.string as label, '','','', '' from rocpd_api INNER JOIN rocpd_string A on A.id = rocpd_api.apiName_id AND A.string = 'UserMarker' INNER JOIN rocpd_string B on B.id = rocpd_api.args_id AND rocpd_api.start/1000.0 != rocpd_api.end/1000.0 %s UNION ALL SELECT '2', rocpd_api.start/1000.0, pid, tid, '' as label, gpuId, queueId, rocpd_op.start/1000.0, rocpd_op.end/1000.0 from rocpd_api_ops INNER JOIN rocpd_api ON rocpd_api_ops.api_id = rocpd_api.id INNER JOIN rocpd_op ON rocpd_api_ops.op_id = rocpd_op.id %s ORDER BY start/1000.0 asc"
+        % (rangeStringApi, rangeStringApi, rangeStringApi)
+    ):
+        try:
+            key = (row[2], row[3]) # Key is 'pid,tid'
+            if row[0] == "0": # Frame start
+                if key not in stacks:
+                    stacks[key] = []
+                stack = stacks[key].append((row[1], row[4]))
+                # print(f"0: new api frame: pid_tid={key} -> stack={stacks}")
+
+            elif row[0] == "1": # Frame end
+                completed = stacks[key].pop()
+                # print(f"1: end api frame: pid_tid={key} -> stack={stacks}")
+
+            elif row[0] == "2": # API + Op
+                if key in stacks and len(stacks[key]) > 0:
+                    frame = stacks[key][-1]
+                    # print(f"2: Op on {frame} ({len(stacks[key])})")
+                    gpuFrame = None
+                    if key not in currentFrame: # First op under the current api frame
+                        gpuFrame = GpuFrame()
+                        gpuFrame.id = frame[0]
+                        gpuFrame.name = frame[1]
+                        gpuFrame.start = row[7]
+                        gpuFrame.end = row[8]
+                        gpuFrame.gpus.append((row[5], row[6]))
+                        gpuFrame.totalOps = 1
+                        # print(f"2a: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
+                    else:
+                        gpuFrame = currentFrame[key]
+                        # Another op under the same frame -> union them (but only if they are butt together)
+                        if (
+                            gpuFrame.id == frame[0]
+                            and gpuFrame.name == frame[1]
+                            and (
+                                abs(row[7] - gpuFrame.end) < 200
+                                or abs(gpuFrame.start - row[8]) < 200
+                            )
+                        ):
+                            # if gpuFrame.id == frame[0] and gpuFrame.name == frame[1]: # Another op under the same frame -> union them
+                            # if False: # Turn off frame joining
+                            if row[7] < gpuFrame.start:
+                                gpuFrame.start = row[7]
+                            if row[8] > gpuFrame.end:
+                                gpuFrame.end = row[8]
+                            if (row[5], row[6]) not in gpuFrame.gpus:
+                                gpuFrame.gpus.append((row[5], row[6]))
+                            gpuFrame.totalOps = gpuFrame.totalOps + 1
+                            # print(f"2c: union frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
+
+                        else: # This is a new frame - dump the last and make new
+                            gpuFrame = currentFrame[key]
+                            for dest in gpuFrame.gpus:
+                                # print(f"2: OUTPUT: dest={dest} time={gpuFrame.start} -> {gpuFrame.end} Duration={gpuFrame.end - gpuFrame.start} TotalOps={gpuFrame.totalOps}")
+                                outfile.write(
+                                    ',{"pid":"%s","tid":"%s","name":"%s","ts":"%s","dur":"%s","ph":"X","args":{"desc":"%s"}}\n'
+                                    % (
+                                        dest[0],
+                                        dest[1],
+                                        gpuFrame.name.replace('"', ""),
+                                        gpuFrame.start - 1,
+                                        gpuFrame.end - gpuFrame.start + 1,
+                                        f"UserMarker frame: {gpuFrame.totalOps} ops",
+                                    )
+                                )
+                            currentFrame.pop(key)
+
+                            # make the first op under the new frame
+                            gpuFrame = GpuFrame()
+                            gpuFrame.id = frame[0]
+                            gpuFrame.name = frame[1]
+                            gpuFrame.start = row[7]
+                            gpuFrame.end = row[8]
+                            gpuFrame.gpus.append((row[5], row[6]))
+                            gpuFrame.totalOps = 1
+                            # print(f"2b: new frame: {gpuFrame.gpus} {gpuFrame.start} {gpuFrame.end} {gpuFrame.end - gpuFrame.start}")
+
+                    currentFrame[key] = gpuFrame
+
+        except ValueError:
+            outfile.write("")
+
+    outfile.write("]\n")
+
+    if format == "object":
+        outfile.write("} \n")
+
+    outfile.close()
+    connection.close()
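
A hedged usage sketch for the new helper (the trace file name is illustrative; it assumes an .rpd SQLite trace recorded with ROCm's rocmProfileData tooling):

    from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace

    # Convert the first half of an rpd trace to a gzipped Chrome trace.
    # With output_json=None the output path is the input path with its
    # suffix replaced by .trace.json.gz. start/end accept a percentage of
    # the trace span ("50%") or an absolute timestamp in microseconds.
    rpd_to_chrome_trace("trace.rpd", start="0%", end="50%")
    # The resulting file loads in chrome://tracing or Perfetto.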
sglang/srt/utils/slow_rank_detector.py
ADDED
@@ -0,0 +1,71 @@
+import logging
+from typing import Any, Dict, List
+
+import torch
+import torch.distributed as dist
+import triton
+
+logger = logging.getLogger(__name__)
+
+
+def execute():
+    if dist.get_rank() == 0:
+        logger.info(f"[slow_rank_detector] Start benchmarking...")
+
+    local_metrics = {
+        bench_name: _compute_local_metric(bench_name) for bench_name in _BENCH_NAMES
+    }
+
+    all_metrics = [None for _ in range(dist.get_world_size())]
+    dist.gather_object(local_metrics, all_metrics if dist.get_rank() == 0 else None)
+
+    if dist.get_rank() == 0:
+        _analyze_metrics(all_metrics)
+
+
+class _GemmExecutor:
+    def __init__(self):
+        self.lhs = torch.randn((8192, 8192), dtype=torch.bfloat16, device="cuda")
+        self.rhs = torch.randn((8192, 8192), dtype=torch.bfloat16, device="cuda")
+
+    def __call__(self):
+        self.lhs @ self.rhs
+
+
+class _ElementwiseExecutor:
+    def __init__(self):
+        self.value = torch.randint(
+            0, 10000, (128 * 1024**2,), dtype=torch.int32, device="cuda"
+        )
+
+    def __call__(self):
+        self.value += 1
+
+
+_EXECUTOR_CLS_OF_BENCH = {
+    "gemm": _GemmExecutor,
+    "elementwise": _ElementwiseExecutor,
+}
+
+_BENCH_NAMES = list(_EXECUTOR_CLS_OF_BENCH.keys())
+
+
+def _compute_local_metric(bench_name):
+    executor = _EXECUTOR_CLS_OF_BENCH[bench_name]()
+    ms = triton.testing.do_bench_cudagraph(executor, return_mode="mean", rep=20)
+    return ms
+
+
+def _analyze_metrics(all_metrics: List[Dict[str, Any]]):
+    for bench_name in _BENCH_NAMES:
+        time_of_rank = torch.tensor([m[bench_name] for m in all_metrics])
+        speed_of_rank = 1 / time_of_rank
+        rel_speed_of_rank = speed_of_rank / speed_of_rank.max()
+        slowest_rel_speed = rel_speed_of_rank.min().item()
+        logger.info(
+            f"[slow_rank_detector] {bench_name=} {slowest_rel_speed=} {rel_speed_of_rank=} {time_of_rank=}"
+        )
+        if slowest_rel_speed < 0.9:
+            logger.warning(
+                "[slow_rank_detector] Some ranks are too slow compared with others"
+            )
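
A hedged sketch of driving the detector standalone (sglang invokes it where torch.distributed is already initialized; the launcher and init code below are illustrative):

    # e.g. torchrun --nproc-per-node=8 check_ranks.py
    import torch
    import torch.distributed as dist

    from sglang.srt.utils import slow_rank_detector

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Each rank times a bf16 8192x8192 GEMM and a 128M-element int add
    # under CUDA graphs; rank 0 gathers the timings and warns when the
    # slowest rank falls below 0.9x the fastest rank's speed.
    slow_rank_detector.execute()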
sglang/srt/warmup.py
CHANGED
@@ -1,20 +1,24 @@
+from __future__ import annotations
+
 import logging
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import tqdm
 
 from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
 from sglang.srt.managers.io_struct import GenerateReqInput
-
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
 
 logger = logging.getLogger(__file__)
 
 _warmup_registry = {}
 
 
-def warmup(name: str)
-    def decorator(fn
+def warmup(name: str):
+    def decorator(fn):
         _warmup_registry[name] = fn
         return fn
 
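
The registry itself is unchanged: warmup(name) returns a decorator that files the function under name in _warmup_registry, and the new TYPE_CHECKING import makes the TokenizerManager annotation available to type checkers without a runtime import cycle. A hedged sketch of registering a warmup (the registry key and function body are illustrative):

    from sglang.srt.warmup import warmup

    @warmup("my_warmup")  # illustrative registry key, not from the diff
    async def _my_warmup(tokenizer_manager):  # assumed call signature
        # Stored in _warmup_registry["my_warmup"]; registered entries can
        # then be selected by name when the server warms up.
        ...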
sglang/srt/weight_sync/utils.py
CHANGED
@@ -33,7 +33,7 @@ async def update_weights(
     """
    infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
    infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
-    from sglang.srt.patch_torch import monkey_patch_torch_reductions
+    from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 
    monkey_patch_torch_reductions()
 