sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/tracing/trace.py
@@ -15,6 +15,8 @@
 
 from __future__ import annotations
 
+import base64
+import json
 import logging
 import os
 import random
@@ -24,6 +26,8 @@ import uuid
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from sglang.srt.utils import get_int_env_var
+
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Req
 
@@ -85,6 +89,8 @@ class SglangTraceReqContext:
     # Indicates whether this instance is a replica from the main process.
     # When True, root_span is None and only root_span_context is preserved.
     is_copy: bool = False
+    bootstrap_room_span: Optional[trace.span.Span] = None
+    bootstrap_room_span_context: Optional[context.Context] = None
     root_span: Optional[trace.span.Span] = None
     root_span_context: Optional[context.Context] = None
 
@@ -96,8 +102,7 @@ class SglangTracePropagateContext:
 
     def to_dict(self):
         carrier: dict[str, str] = {}
-        context.attach(self.root_span_context)
-        propagate.inject(carrier)
+        propagate.inject(carrier, self.root_span_context)
 
         if self.prev_span_context:
             return {
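As a side note (a minimal sketch, not taken from this package): the change above passes the stored context to propagate.inject instead of attaching it globally first. The snippet below shows that inject/extract round-trip with opentelemetry-python; the tracer name is illustrative.

# Sketch only: inject a span context into a dict carrier without attaching it
# to the current thread, then recover it on the other side.
from opentelemetry import propagate, trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("example")  # illustrative name

span = tracer.start_span("parent")
ctx = trace.set_span_in_context(span)

carrier: dict[str, str] = {}
propagate.inject(carrier, ctx)  # context passed explicitly, as in the new code
restored = trace.get_current_span(propagate.extract(carrier))
assert restored.get_span_context().trace_id == span.get_span_context().trace_id
span.end()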
@@ -149,6 +154,7 @@ class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
 
 
 # global variables
+remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {}
 threads_info: Dict[int, SglangTraceThreadInfo] = {}
 reqs_context: Dict[str, SglangTraceReqContext] = {}
 
@@ -193,8 +199,17 @@ def process_tracing_init(otlp_endpoint, server_name):
         resource=resource, id_generator=SglangTraceCustomIdGenerator()
     )
 
+    schedule_delay_millis = get_int_env_var(
+        "SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500
+    )
+    max_export_batch_size = get_int_env_var(
+        "SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64
+    )
+
     processor = BatchSpanProcessor(
-        OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
+        OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True),
+        schedule_delay_millis=schedule_delay_millis,
+        max_export_batch_size=max_export_batch_size,
     )
     tracer_provider.add_span_processor(processor)
     trace.set_tracer_provider(tracer_provider)
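Usage note (example values, not part of the diff): the two new knobs are plain environment variables read by get_int_env_var at tracing init, so they can be exported before the process starts.

# Example only: tune the OTLP span exporter batching before tracing starts.
import os

os.environ["SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS"] = "200"  # flush every 200 ms
os.environ["SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE"] = "128"  # up to 128 spans per batch
# process_tracing_init(...) then picks these up via get_int_env_var.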
@@ -266,7 +281,9 @@ def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
     return thread_context
 
 
-def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
+def trace_get_proc_propagate_context(
+    rid, remote_propagate=False
+) -> Optional[Dict[str, Any]]:
     if not tracing_enabled:
         return None
 
@@ -283,9 +300,11 @@ def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]:
     elif thread_context.last_span_context:
         prev_span_context = thread_context.last_span_context
 
-    trace_context = SglangTracePropagateContext(
-        reqs_context[rid].root_span_context, prev_span_context
-    )
+    root_span_context = reqs_context[rid].root_span_context
+    if remote_propagate:
+        root_span_context = reqs_context[rid].bootstrap_room_span_context
+
+    trace_context = SglangTracePropagateContext(root_span_context, prev_span_context)
     return trace_context.to_dict()
 
 
@@ -327,10 +346,54 @@ def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]
     ].last_span_context = trace_context.prev_span_context
 
 
+def trace_get_remote_propagate_context(bootstrap_room_list: List[str]):
+    if not tracing_enabled:
+        return ""
+
+    reqs_trace_contexts = {}
+    for bootstrap_room in bootstrap_room_list:
+        # In the router, rid is also the bootstrap room.
+        bootstrap_room = str(bootstrap_room)
+
+        if bootstrap_room not in reqs_context:
+            continue
+
+        _context = trace_get_proc_propagate_context(
+            bootstrap_room, remote_propagate=True
+        )
+        reqs_trace_contexts[bootstrap_room] = _context
+
+    json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False)
+    return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
+
+
+def trace_set_remote_propagate_context(base64_str):
+    if not tracing_enabled:
+        return
+
+    if base64_str is None or base64_str == "" or base64_str == "None":
+        return
+
+    base64_bytes = base64.b64decode(base64_str)
+    json_str = base64_bytes.decode("utf-8")
+    remote_reqs_trace_contexts = json.loads(json_str)
+
+    for bootstrap_room in remote_reqs_trace_contexts:
+        if bootstrap_room in remote_trace_contexts:
+            continue
+
+        remote_trace_contexts[bootstrap_room] = (
+            SglangTracePropagateContext.instance_from_dict(
+                remote_reqs_trace_contexts[bootstrap_room]
+            )
+        )
+
+
 def trace_req_start(
     rid: str,
     bootstrap_room: Optional[int] = None,
     ts: Optional[int] = None,
+    role: Optional[str] = "null",
 ):
     if not tracing_enabled:
         return
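For orientation, a minimal sketch of the wire format these two helpers exchange: a JSON object keyed by bootstrap room, base64-encoded for transport. The carrier contents below are hypothetical; real carriers come from propagate.inject().

# Sketch of the encode/decode round-trip used for remote trace propagation.
import base64
import json

reqs_trace_contexts = {"12345": {"traceparent": "00-...-...-01"}}  # hypothetical carrier

encoded = base64.b64encode(
    json.dumps(reqs_trace_contexts, ensure_ascii=False).encode("utf-8")
).decode("utf-8")

decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded == reqs_trace_contexts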
@@ -344,6 +407,7 @@ def trace_req_start(
         return
 
     # create req context and root span
+    bootstrap_room = 0 if bootstrap_room is None else bootstrap_room
     reqs_context[rid] = SglangTraceReqContext(
         rid=rid,
         start_time_ns=ts,
@@ -352,23 +416,42 @@ def trace_req_start(
         is_copy=False,
     )
 
+    # create bootstrap room span
+    tracer = threads_info[pid].tracer
+    if str(bootstrap_room) not in remote_trace_contexts:
+        attrs = {"bootstrap_room": str(hex(bootstrap_room))}
+        bootstrap_room_span = tracer.start_span(
+            name=f"Bootstrap Room {hex(bootstrap_room)}",
+            start_time=ts,
+            attributes=attrs,
+        )
+        reqs_context[rid].bootstrap_room_span = bootstrap_room_span
+        bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span)
+    else:
+        bootstrap_room_span_context = remote_trace_contexts[
+            str(bootstrap_room)
+        ].root_span_context
+
     # Drop the worker_id added by MultiTokenizer
     orig_rid = rid.split("_")[-1]
-    tracer = threads_info[pid].tracer
+    role = "" if role == "null" else role
+    attrs = {"rid": orig_rid}
     root_span = tracer.start_span(
-        name=f"Req {orig_rid[:8]}",
+        name=f"{role} Req {orig_rid[:8]}",
         start_time=ts,
+        context=bootstrap_room_span_context,
+        attributes=attrs,
     )
 
     root_span.set_attributes(
         {
             "rid": rid,
-            "bootstrap_room": bootstrap_room if bootstrap_room else "None",
         }
     )
 
     reqs_context[rid].root_span = root_span
     reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
+    reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context
 
     # create thread context and thread span
     reqs_context[rid].threads_context[pid] = __create_thread_context(
@@ -376,6 +459,10 @@ def trace_req_start(
         reqs_context[rid].root_span_context,
         ts,
     )
+    if str(bootstrap_room) in remote_trace_contexts:
+        reqs_context[rid].threads_context[pid].last_span_context = (
+            remote_trace_contexts[str(bootstrap_room)].prev_span_context
+        )
 
 
 def trace_req_finish(
@@ -399,6 +486,10 @@ def trace_req_finish(
         req_context.root_span.set_attributes(attrs)
 
     req_context.root_span.end(end_time=ts)
+    if str(req_context.bootstrap_room) in remote_trace_contexts:
+        del remote_trace_contexts[str(req_context.bootstrap_room)]
+    else:
+        req_context.bootstrap_room_span.end(end_time=ts)
 
     del reqs_context[rid]
 
@@ -518,7 +609,9 @@ trace_slice = trace_slice_end
 
 
 # Add event to the current slice on the same thread with the same rid.
-def trace_event(name: str, rid: str, ts: Optional[int] = None):
+def trace_event(
+    name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None
+):
     if not tracing_enabled:
         return
 
@@ -539,7 +632,7 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None):
     ts = ts or __get_cur_time_ns()
 
     slice_info = thread_context.cur_slice_stack[-1]
-    slice_info.span.add_event(name=name, timestamp=ts)
+    slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs)
 
 
 # Add attrs to the current slice on the same thread with the same rid.
@@ -569,6 +662,9 @@ def trace_slice_batch(
     name: str,
     reqs: List[Req],
 ):
+    if not tracing_enabled:
+        return
+
     for req in reqs:
         trace_slice(
             name,
@@ -576,3 +672,16 @@ def trace_slice_batch(
             auto_next_anon=not req.finished(),
             thread_finish_flag=req.finished(),
         )
+
+
+def trace_event_batch(
+    name: str,
+    reqs: List[Req],
+    ts: Optional[int] = None,
+    attrs: Dict[str, Any] = None,
+):
+    if not tracing_enabled:
+        return
+
+    for req in reqs:
+        trace_event(name, req.rid, ts=ts, attrs=attrs)
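A short usage sketch for the extended trace_event signature (request id and attribute values are made up; with tracing disabled both calls are no-ops):

# Hypothetical call sites; assumes the opentelemetry packages are installed.
from sglang.srt.tracing.trace import trace_event, trace_event_batch

trace_event("schedule", rid="req-abc123", attrs={"batch_size": 8})
trace_event_batch("forward_done", reqs=[], attrs={"phase": "decode"})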
sglang/srt/utils/common.py
@@ -188,7 +188,16 @@ is_hopper_with_cuda_12_3 = lambda: _check(9)
 def is_blackwell():
     if not is_cuda():
         return False
-    return torch.cuda.get_device_capability()[0] == 10
+    return torch.cuda.get_device_capability()[0] in [10, 12]
+
+
+@lru_cache(maxsize=1)
+def is_blackwell_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
+    return (torch.cuda.get_device_capability(device)[0] in [10, 12]) and (
+        torch.version.cuda >= "12.8"
+    )
 
 
 @lru_cache(maxsize=1)
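For context, a standalone sketch of the capability probe these helpers are built on; compute capability 10.x and 12.x correspond to the Blackwell families.

# Standalone sketch: report whether the visible GPU is a Blackwell part.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"sm_{major}{minor}, Blackwell family: {major in (10, 12)}")
    print(f"CUDA runtime: {torch.version.cuda}")
else:
    print("no CUDA device visible")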
@@ -1230,42 +1239,34 @@ def point_to_point_pyobj(
     dst: int = 1,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
-
+    device = torch.get_device_module().current_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
-                [0], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).cuda(
-                device=torch.cuda.current_device()
+            ).to(
+                device=device
             )  # Move to GPU
-            tensor_size = torch.tensor(
-                [size], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
 
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
 
     elif rank == dst:
-        tensor_size = torch.tensor(
-            [0], dtype=torch.long, device=torch.cuda.current_device()
-        )
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
        dist.recv(tensor_size, src=src, group=group)
        size = tensor_size.item()
 
        if size == 0:
            return []
 
-        tensor_data = torch.empty(
-            size, dtype=torch.uint8, device=torch.cuda.current_device()
-        )
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
        dist.recv(tensor_data, src=src, group=group)
 
        serialized_data = bytes(
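A CPU-only sketch (so it runs without a GPU) of the serialization pattern point_to_point_pyobj relies on: pickle to bytes, view as a uint8 tensor, then back.

# Sketch: round-trip a Python object through a uint8 tensor, the same
# encoding point_to_point_pyobj sends over torch.distributed.
import pickle

import numpy as np
import torch

obj = {"step": 3, "ids": [1, 2, 3]}
payload = pickle.dumps(obj)
tensor = torch.from_numpy(np.frombuffer(payload, dtype=np.uint8).copy())

restored = pickle.loads(tensor.numpy().tobytes())
assert restored == obj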
@@ -2350,16 +2351,24 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
     )
     server = uvicorn.Server(config=config)
 
-    try:
-        loop = asyncio.get_running_loop()
-        logger.info(
-            f"Dummy health check server scheduled on existing loop at {host}:{port}"
-        )
-        loop.create_task(server.serve())
+    # Run server in a background daemon thread with its own event loop
+    # This prevents blocking the main thread while still serving health checks
+    def run_server():
+        try:
+            asyncio.run(server.serve())
+        except Exception as e:
+            logger.error(f"Dummy health check server failed to start: {e}")
+            raise
+        finally:
+            logger.info(f"Dummy health check server stopped at {host}:{port}")
 
-    except RuntimeError:
-        logger.info(f"Starting dummy health check server at {host}:{port}")
-        server.run()
+    thread = threading.Thread(
+        target=run_server, daemon=True, name="health-check-server"
+    )
+    thread.start()
+    logger.info(
+        f"Dummy health check server started in background thread at {host}:{port}"
+    )
 
 
 def create_checksum(directory: str):
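A minimal standalone sketch of the daemon-thread pattern adopted above: uvicorn's async serve() is driven by asyncio.run inside a background thread so the caller never blocks. The app, host, and port here are illustrative.

# Sketch: serve a tiny FastAPI app from a daemon thread with its own event loop.
import asyncio
import threading

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
async def health():
    return {"status": "ok"}


config = uvicorn.Config(app, host="127.0.0.1", port=8099, log_level="warning")
server = uvicorn.Server(config=config)

thread = threading.Thread(target=lambda: asyncio.run(server.serve()), daemon=True)
thread.start()
# The main thread continues; /health is answered in the background.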
@@ -3105,12 +3114,16 @@ def apply_module_patch(target_module, target_function, wrappers):
     setattr(original_module, target_function, candidate)
 
     for key, value in sys.modules.copy().items():
-        if (
-            target_function is not None
-            and hasattr(value, target_function)
-            and id(getattr(value, target_function)) == original_function_id
-        ):
-            setattr(value, target_function, candidate)
+        try:
+            if (
+                target_function is not None
+                and hasattr(value, target_function)
+                and id(getattr(value, target_function)) == original_function_id
+            ):
+                setattr(value, target_function, candidate)
+        except ImportError as e:
+            # Ignore some modules reporting ImportError when calling hasattr
+            logger.warning(f"Ignore {value} reports ImportError with:\n{str(e)}")
 
 
 def parse_module_path(module_path, function_name, create_dummy):
@@ -3562,7 +3575,17 @@ def cached_triton_kernel(key_fn=None):
     """
 
     def decorator(fn):
-        return CachedKernel(fn, key_fn)
+        if envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.get():
+            logger.debug(
+                f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = True. Using custom triton kernel cache."
+            )
+            return CachedKernel(fn, key_fn)
+        else:
+            # Fallback to the native triton cache.
+            logger.debug(
+                f"{envs.SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE.name} = False. Using native triton kernel cache."
+            )
+            return fn
 
     return decorator
 
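As a generic illustration of the gating pattern above (the flag name and caching logic below are made up, not sglang's envs API): the decorator either wraps the function or returns it untouched, decided once at decoration time.

# Generic sketch of an env-gated decorator; USE_CUSTOM_CACHE is a made-up flag.
import functools
import os


def maybe_cached(fn):
    if os.environ.get("USE_CUSTOM_CACHE", "0") == "1":
        cache = {}

        @functools.wraps(fn)
        def wrapped(*args):
            if args not in cache:
                cache[args] = fn(*args)
            return cache[args]

        return wrapped
    # Flag off: hand back the original function unchanged.
    return fn


@maybe_cached
def square(x):
    return x * x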
@@ -43,6 +43,7 @@ from sglang.srt.configs import (
43
43
  DotsVLMConfig,
44
44
  ExaoneConfig,
45
45
  FalconH1Config,
46
+ KimiLinearConfig,
46
47
  KimiVLConfig,
47
48
  LongcatFlashConfig,
48
49
  MultiModalityConfig,
@@ -54,6 +55,7 @@ from sglang.srt.configs import (
54
55
  from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
55
56
  from sglang.srt.configs.internvl import InternVLChatConfig
56
57
  from sglang.srt.connector import create_remote_connector
58
+ from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
57
59
  from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
58
60
 
59
61
  _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
@@ -67,6 +69,7 @@ _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
67
69
  Step3VLConfig,
68
70
  LongcatFlashConfig,
69
71
  Olmo3Config,
72
+ KimiLinearConfig,
70
73
  Qwen3NextConfig,
71
74
  FalconH1Config,
72
75
  DotsVLMConfig,
@@ -172,6 +175,16 @@ def _load_deepseek_v32_model(
172
175
  )
173
176
 
174
177
 
178
+ def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool:
179
+ # TODO: Remove this workaround related when AutoConfig correctly identifies deepseek-ocr.
180
+ # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
181
+ return (
182
+ getattr(config, "auto_map", None) is not None
183
+ and config.auto_map.get("AutoModel")
184
+ == "modeling_deepseekocr.DeepseekOCRForCausalLM"
185
+ )
186
+
187
+
175
188
  @lru_cache_frozenset(maxsize=32)
176
189
  def get_config(
177
190
  model: str,
@@ -197,14 +210,6 @@ def get_config(
197
210
  config = AutoConfig.from_pretrained(
198
211
  model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
199
212
  )
200
- if (
201
- getattr(config, "auto_map", None) is not None
202
- and config.auto_map.get("AutoModel")
203
- == "modeling_deepseekocr.DeepseekOCRForCausalLM"
204
- ):
205
- config.model_type = "deepseek-ocr"
206
- # TODO: Remove this workaround when AutoConfig correctly identifies deepseek-ocr.
207
- # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
208
213
 
209
214
  except ValueError as e:
210
215
  if not "deepseek_v32" in str(e):
@@ -241,7 +246,11 @@ def get_config(
241
246
  setattr(config, key, val)
242
247
 
243
248
  if config.model_type in _CONFIG_REGISTRY:
244
- config_class = _CONFIG_REGISTRY[config.model_type]
249
+ model_type = config.model_type
250
+ if model_type == "deepseek_vl_v2":
251
+ if _is_deepseek_ocr_model(config):
252
+ model_type = "deepseek-ocr"
253
+ config_class = _CONFIG_REGISTRY[model_type]
245
254
  config = config_class.from_pretrained(model, revision=revision)
246
255
  # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
247
256
  setattr(config, "_name_or_path", model)
@@ -445,6 +454,10 @@ def get_processor(
445
454
  **kwargs,
446
455
  )
447
456
 
457
+ if _is_deepseek_ocr_model(config):
458
+ # Temporary hack for load deepseek-ocr
459
+ config.model_type = "deepseek-ocr"
460
+
448
461
  # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
449
462
  if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
450
463
  if "size" not in kwargs:
@@ -462,13 +475,22 @@ def get_processor(
462
475
  **kwargs,
463
476
  )
464
477
  else:
465
- processor = AutoProcessor.from_pretrained(
466
- tokenizer_name,
467
- *args,
468
- trust_remote_code=trust_remote_code,
469
- revision=revision,
470
- **kwargs,
471
- )
478
+ if config.model_type in _CUSTOMIZED_MM_PROCESSOR:
479
+ processor = _CUSTOMIZED_MM_PROCESSOR[config.model_type].from_pretrained(
480
+ tokenizer_name,
481
+ *args,
482
+ trust_remote_code=trust_remote_code,
483
+ revision=revision,
484
+ **kwargs,
485
+ )
486
+ else:
487
+ processor = AutoProcessor.from_pretrained(
488
+ tokenizer_name,
489
+ *args,
490
+ trust_remote_code=trust_remote_code,
491
+ revision=revision,
492
+ **kwargs,
493
+ )
472
494
 
473
495
  except ValueError as e:
474
496
  error_message = str(e)
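A simplified sketch of the dispatch get_processor now performs: model types present in _CUSTOMIZED_MM_PROCESSOR load their own processor class, everything else falls back to AutoProcessor. The registry entry and processor class below are hypothetical.

# Hypothetical registry and fallback, mirroring the dispatch logic above.
from transformers import AutoProcessor


class MyCustomProcessor:  # stand-in for a model-specific processor class
    @classmethod
    def from_pretrained(cls, name, **kwargs):
        return cls()


_CUSTOMIZED_MM_PROCESSOR = {"my-model-type": MyCustomProcessor}


def load_processor(model_type: str, tokenizer_name: str, **kwargs):
    if model_type in _CUSTOMIZED_MM_PROCESSOR:
        return _CUSTOMIZED_MM_PROCESSOR[model_type].from_pretrained(tokenizer_name, **kwargs)
    return AutoProcessor.from_pretrained(tokenizer_name, **kwargs)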
sglang/srt/utils/torch_memory_saver_adapter.py
@@ -41,6 +41,12 @@ class TorchMemorySaverAdapter(ABC):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError
 
+    def cuda_graph(self, **kwargs):
+        raise NotImplementedError
+
+    def disable(self):
+        raise NotImplementedError
+
     def pause(self, tag: str):
         raise NotImplementedError
 
@@ -61,6 +67,12 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
 
+    def cuda_graph(self, **kwargs):
+        return _memory_saver.cuda_graph(**kwargs)
+
+    def disable(self):
+        return _memory_saver.disable()
+
     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
 
@@ -81,6 +93,14 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         yield
 
+    @contextmanager
+    def cuda_graph(self, **kwargs):
+        yield
+
+    @contextmanager
+    def disable(self):
+        yield
+
     def pause(self, tag: str):
         pass
 
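A small sketch of the no-op context-manager pattern the Noop adapter uses, so call sites can always write "with adapter.cuda_graph(...)" or "with adapter.disable()" even when torch_memory_saver is unavailable; the class below is illustrative, not sglang's.

# Sketch: no-op context managers keep call sites uniform when the real
# backend is absent; the class below is illustrative, not sglang's.
from contextlib import contextmanager


class NoopSaver:
    @contextmanager
    def cuda_graph(self, **kwargs):
        yield

    @contextmanager
    def disable(self):
        yield


saver = NoopSaver()
with saver.disable():
    pass  # body runs unchanged; nothing is paused or resumed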
sglang/test/kits/radix_cache_server_kit.py (new file)
@@ -0,0 +1,50 @@
+import random
+
+import requests
+
+
+def gen_radix_tree(num_nodes=400, chunk_len=256):
+    num0 = num_nodes // 2
+    num1 = num_nodes - num0
+    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
+    for _ in range(num0):
+        parent = random.choice(nodes)
+        unique_len = random.randint(0, chunk_len)
+        decode_len = random.randint(0, chunk_len)
+        token_id = random.randint(0, 32000)
+        child = {
+            "input_ids": parent["input_ids"] + [token_id] * unique_len,
+            "decode_len": decode_len,
+        }
+        nodes.append(child)
+
+    while num1 > 0:
+        num_branch = random.randint(1, min(num1, 10))
+        parent = random.choice(nodes)
+        for _ in range(num_branch):
+            unique_len = random.randint(0, chunk_len)
+            decode_len = random.randint(0, chunk_len)
+            token_id = random.randint(0, 32000)
+            child = {
+                "input_ids": parent["input_ids"] + [token_id] * unique_len,
+                "decode_len": decode_len,
+            }
+            nodes.append(child)
+
+        num1 -= num_branch
+
+    random.shuffle(nodes)
+    return nodes
+
+
+def run_radix_attention_test(base_url: str):
+    nodes = gen_radix_tree()
+    data = {
+        "input_ids": [node["input_ids"] for node in nodes],
+        "sampling_params": [
+            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
+        ],
+    }
+
+    res = requests.post(base_url + "/generate", json=data)
+    assert res.status_code == 200
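A usage sketch for the new test kit; the server URL is an assumption (a locally launched sglang server):

# Example only: exercise radix-cache prefix sharing against a running server.
from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test

run_radix_attention_test("http://127.0.0.1:30000")  # assumed local server URL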