sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import dataclasses
 import logging
 from dataclasses import replace
@@ -17,15 +18,21 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    compute_position,
+)
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, get_bool_env_var
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
 
+_is_hip = is_hip()
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -58,7 +65,7 @@ def compute_split_seq_index(
 ) -> Optional[int]:
     if forward_mode == ForwardMode.EXTEND:
         assert extend_lens is not None
-        return _split_array_by_half_sum(extend_lens)
+        return _split_extend_seqs(extend_lens)
     elif forward_mode.is_target_verify() or forward_mode.is_decode():
         assert token_num_per_seq is not None
         return (num_tokens // token_num_per_seq) // 2
@@ -69,7 +76,43 @@ def compute_split_seq_index(
         raise NotImplementedError()
 
 
-def _split_array_by_half_sum(arr: Sequence[int]) -> int:
+def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
+    if extend_lens is None:
+        return False
+
+    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
+    left_sum = sum(extend_lens[:vanilla_split_seq_index])
+    overall_sum = sum(extend_lens)
+    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
+    assert threshold <= 0.5, f"{threshold=}"
+    return left_sum < overall_sum * threshold or left_sum > overall_sum * (
+        1 - threshold
+    )
+
+
+def _split_extend_seqs(arr: Sequence[int]) -> int:
+    if _is_two_chunk_split_enabled(arr):
+        return _split_array_by_cum_less_than_half(arr)
+
+    return _split_array_by_balanced_sum(arr)
+
+
+def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
+    left_sum = 0
+    overall_sum = sum(arr)
+    half_sum = overall_sum // 2
+    chosen_index = 0
+
+    for i in range(len(arr)):
+        left_sum += arr[i]
+        if left_sum > half_sum:
+            chosen_index = i
+            break
+
+    return chosen_index
+
+
+def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
     overall_sum = sum(arr)
     left_sum = 0
     min_diff = float("inf")
@@ -88,6 +131,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
     return best_index
 
 
+def _update_device_and_sum_field_from_cpu_field(
+    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: str = None
+):
+    cpu_value = getattr(batch, cpu_field, None)
+    old_device_value = getattr(batch, device_field, None)
+    if (
+        cpu_value is None
+        or old_device_value is None
+        or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
+    ):
+        return
+
+    new_device_value = (
+        cpu_value
+        if isinstance(cpu_value, torch.Tensor)
+        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
+    ).to(device=global_server_args_dict["device"], non_blocking=True)
+    setattr(batch, device_field, new_device_value)
+
+    if sum_field is not None:
+        sum_value = (
+            cpu_value.sum().item()
+            if isinstance(cpu_value, torch.Tensor)
+            else sum(cpu_value)
+        )
+        setattr(batch, sum_field, sum_value)
+
+
 def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
     if seq_index == 0:
         return 0
@@ -181,6 +252,8 @@ compute_split_token_index(
 ) -> int:
     if forward_mode == ForwardMode.EXTEND:
         assert extend_seq_lens is not None
+        if _is_two_chunk_split_enabled(extend_seq_lens):
+            return sum(extend_seq_lens) // 2
         return sum(extend_seq_lens[:split_seq_index])
     elif forward_mode.is_target_verify() or forward_mode.is_decode():
         assert token_num_per_seq is not None
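For intuition: when one sequence dominates the batch, no sequence boundary gives a balanced split, `_is_two_chunk_split_enabled` fires, and the split point becomes a token index rather than a sequence boundary. A standalone sketch of the arithmetic, with the helper copied from the hunk above and made-up lengths:

from typing import Sequence

def split_by_cum_less_than_half(arr: Sequence[int]) -> int:
    # Copy of _split_array_by_cum_less_than_half: index of the sequence
    # whose tokens straddle the halfway point of the batch.
    left_sum, half_sum, chosen_index = 0, sum(arr) // 2, 0
    for i, n in enumerate(arr):
        left_sum += n
        if left_sum > half_sum:
            chosen_index = i
            break
    return chosen_index

extend_lens = [5, 5, 90]            # 100 prefill tokens, one huge sequence
assert split_by_cum_less_than_half(extend_lens) == 2
assert sum(extend_lens) // 2 == 50  # what compute_split_token_index returns

# Sequence 2 straddles the cut: micro-batch A takes its first
# 50 - (5 + 5) = 40 tokens and micro-batch B the remaining 50, which is the
# left_last_seq_token_num / right_first_seq_token_num split computed further
# down in derive_fields_related_to_seq_len_for_two_chunk.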
@@ -388,9 +461,15 @@ class TboForwardBatchPreparer:
 
         tbo_split_token_index = cls._compute_split_token_index(batch)
 
+        is_enable_two_chunk = (
+            batch.forward_mode == ForwardMode.EXTEND
+            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
+        )
+
         if _tbo_debug:
             logger.info(
                 f"TboForwardBatchPreparer.prepare "
+                f"is_enable_two_chunk={is_enable_two_chunk} "
                 f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                 f"tbo_split_token_index={tbo_split_token_index} "
                 f"extend_seq_lens={batch.extend_seq_lens_cpu} "
@@ -410,7 +489,11 @@ class TboForwardBatchPreparer:
             start_token_index=0,
             end_token_index=tbo_split_token_index,
             start_seq_index=0,
-            end_seq_index=batch.tbo_split_seq_index,
+            end_seq_index=(
+                batch.tbo_split_seq_index + 1
+                if is_enable_two_chunk
+                else batch.tbo_split_seq_index
+            ),
             output_attn_backend=attn_backend_child_a,
             out_num_token_non_padded=out_num_token_non_padded_a,
         )
@@ -424,9 +507,79 @@ class TboForwardBatchPreparer:
             out_num_token_non_padded=out_num_token_non_padded_b,
         )
 
+        if is_enable_two_chunk:
+            cls.derive_fields_related_to_seq_len_for_two_chunk(
+                batch,
+                child_a=child_a,
+                child_b=child_b,
+                tbo_split_seq_index=batch.tbo_split_seq_index,
+            )
+
         assert batch.tbo_children is None
         batch.tbo_children = [child_a, child_b]
 
+    @classmethod
+    def derive_fields_related_to_seq_len_for_two_chunk(
+        cls,
+        batch: ForwardBatch,
+        *,
+        child_a: ForwardBatch,
+        child_b: ForwardBatch,
+        tbo_split_seq_index: int,
+    ):
+        extend_seq_lens_cpu = batch.extend_seq_lens_cpu
+        overall_seq_lens_sum = sum(extend_seq_lens_cpu)
+        half_seq_lens_sum = overall_seq_lens_sum // 2
+        left_last_seq_token_num = half_seq_lens_sum - sum(
+            extend_seq_lens_cpu[:tbo_split_seq_index]
+        )
+        right_first_seq_token_num = (
+            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
+        )
+
+        # making deepcopy to be extra safe
+        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
+        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
+        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
+        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
+        for child in [child_a, child_b]:
+            _update_device_and_sum_field_from_cpu_field(
+                batch=child,
+                cpu_field="extend_seq_lens_cpu",
+                device_field="extend_seq_lens",
+                sum_field="extend_num_tokens",
+            )
+
+        assert (
+            child_a.extend_num_tokens == half_seq_lens_sum
+        ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"
+
+        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
+        child_a.seq_lens_cpu[-1] = (
+            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
+        )
+        _update_device_and_sum_field_from_cpu_field(
+            batch=child_a,
+            cpu_field="seq_lens_cpu",
+            device_field="seq_lens",
+            sum_field="seq_lens_sum",
+        )
+
+        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
+        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
+        _update_device_and_sum_field_from_cpu_field(
+            batch=child_b,
+            cpu_field="extend_prefix_lens_cpu",
+            device_field="extend_prefix_lens",
+            sum_field=None,
+        )
+        _, child_b.extend_start_loc = compute_position(
+            global_server_args_dict["attention_backend"],
+            child_b.extend_prefix_lens,
+            child_b.extend_seq_lens,
+            child_b.extend_num_tokens,
+        )
+
     @classmethod
     def filter_batch(
         cls,
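The deepcopies in `derive_fields_related_to_seq_len_for_two_chunk` are load-bearing: the child batches are built by slicing the parent's fields, and torch slices are views, so an in-place edit would leak into the parent and the sibling. A minimal illustration of the aliasing hazard (generic PyTorch, not sglang code):

import copy

import torch

seq_lens = torch.tensor([10, 10, 95])
child = seq_lens[:3]           # torch slicing returns a view, not a copy
child[-1] = 50
assert seq_lens[-1] == 50      # the parent sees the in-place edit

safe = copy.deepcopy(seq_lens[:3])
safe[-1] = 95
assert seq_lens[-1] == 50      # deepcopy decouples the child before mutation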
@@ -468,7 +621,7 @@ class TboForwardBatchPreparer:
             "extend_prefix_lens_cpu",
             "extend_seq_lens_cpu",
             "extend_logprob_start_lens_cpu",
-            "lora_paths",
+            "lora_ids",
         ]:
             old_value = getattr(batch, key)
             if old_value is None:
@@ -510,6 +663,7 @@ class TboForwardBatchPreparer:
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
             "split_index",  # for split prefill
+            "orig_seq_lens",  # only used by qwen-1m, thus not care
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
@@ -670,9 +824,15 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
-        operations_strategy.deep_gemm_num_sms
-    ):
+    context = (
+        empty_context()
+        if _is_hip
+        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
+            operations_strategy.deep_gemm_num_sms
+        )
+    )
+
+    with context:
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
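The HIP branch above swaps the DeepGEMM SM-partitioning context for sglang's no-op `empty_context()`. The standard library offers the same pick-a-context-then-enter pattern via `contextlib.nullcontext`; a small sketch (the flag and the chosen real manager are stand-ins, not sglang code):

from contextlib import nullcontext

import torch

skip_configuration = True  # stand-in for _is_hip

ctx = nullcontext() if skip_configuration else torch.inference_mode()
with ctx:                  # `with` enters whichever manager was chosen
    pass                   # no-op branch: nothing is configured or restored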
sglang/srt/utils.py CHANGED
@@ -41,9 +41,11 @@ import tempfile
 import threading
 import time
 import traceback
+import uuid
 import warnings
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -84,6 +86,7 @@ from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import FileCacheManager
+from typing_extensions import Literal
 
 from sglang.srt.metrics.func_timer import enable_func_timer
 
@@ -231,6 +234,10 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
+
+
 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
 )
@@ -736,9 +743,18 @@ def load_audio(
     return audio
 
 
+@dataclass
+class ImageData:
+    url: str
+    detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
 def load_image(
-    image_file: Union[Image.Image, str, bytes],
+    image_file: Union[Image.Image, str, ImageData, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
+    if isinstance(image_file, ImageData):
+        image_file = image_file.url
+
     image = image_size = None
     if isinstance(image_file, Image.Image):
         image = image_file
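Based on the hunk above, `ImageData` wraps an OpenAI-style image url plus a detail hint, and `load_image` simply unwraps the url before its usual dispatch. A round-trip sketch using the base64 str branch shown below (the payload is constructed on the spot, so no real file is assumed):

from io import BytesIO

import pybase64
from PIL import Image

from sglang.srt.utils import ImageData, load_image

buf = BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")
payload = pybase64.b64encode(buf.getvalue()).decode()  # plain base64 string

# ImageData is unwrapped to its url field, then handled like any str input;
# detail ("auto" | "low" | "high") is metadata carried for API callers.
image, _ = load_image(ImageData(url=payload, detail="low"))
assert image.size == (8, 8)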
@@ -762,7 +778,7 @@ def load_image(
     elif isinstance(image_file, str):
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
-        raise ValueError(f"Invalid image: {image}")
+        raise ValueError(f"Invalid image: {image_file}")
 
     return image, image_size
 
@@ -799,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         vr = VideoReader(tmp_file.name, ctx=ctx)
     elif video_file.startswith("data:"):
         _, encoded = video_file.split(",", 1)
-        video_bytes = base64.b64decode(encoded)
+        video_bytes = pybase64.b64decode(encoded)
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
         tmp_file.write(video_bytes)
         tmp_file.close()
@@ -807,7 +823,7 @@
     elif os.path.isfile(video_file):
         vr = VideoReader(video_file, ctx=ctx)
     else:
-        video_bytes = base64.b64decode(video_file)
+        video_bytes = pybase64.b64decode(video_file)
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        tmp_file.write(video_bytes)
         tmp_file.close()
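Both `load_video` call sites switch from the stdlib `base64` to `pybase64`, a SIMD-accelerated drop-in with the same signatures, so behavior is unchanged while decoding large payloads gets faster:

import base64

import pybase64

data = b"\x00\x01\x02" * 1000
encoded = base64.b64encode(data)

# Same results, same signature; pybase64 also accepts validate=, as used
# in load_image above.
assert pybase64.b64decode(encoded) == base64.b64decode(encoded)
assert pybase64.b64decode(encoded, validate=True) == data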
@@ -2113,6 +2129,10 @@ def next_power_of_2(n: int):
     return 1 << (n - 1).bit_length() if n > 0 else 1
 
 
+def round_up(x: int, y: int) -> int:
+    return ((x - 1) // y + 1) * y
+
+
 setattr(triton, "next_power_of_2", next_power_of_2)
 
 
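`round_up` rounds x up to the next multiple of y with pure integer arithmetic: for x = 10, y = 8, ((10 - 1) // 8 + 1) * 8 = 16. Quick checks of the formula:

def round_up(x: int, y: int) -> int:
    return ((x - 1) // y + 1) * y

assert round_up(10, 8) == 16
assert round_up(16, 8) == 16     # exact multiples are left unchanged
assert round_up(1, 256) == 256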
@@ -2832,6 +2852,17 @@ def parse_module_path(module_path, function_name, create_dummy):
     return final_module, None
 
 
+def mxfp_supported():
+    """
+    Returns whether the current platform supports MX types.
+    """
+    if torch.version.hip:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ["gfx95"])
+    else:
+        return False
+
+
 # LoRA-related constants and utilities
 SUPPORTED_LORA_TARGET_MODULES = [
     "q_proj",
@@ -2929,4 +2960,9 @@ class ConcurrentCounter:
         This suspends the calling coroutine without blocking the thread, allowing
         other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
         """
-        self.wait_for(lambda count: count == 0)
+        await self.wait_for(lambda count: count == 0)
+
+
+@lru_cache(maxsize=1)
+def is_triton_kernels_available() -> bool:
+    return importlib.util.find_spec("triton_kernels") is not None
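The `wait_for_zero` change fixes a classic asyncio bug: calling a coroutine method without `await` only creates a coroutine object, so the caller never actually waited. A minimal reproduction of the bug class (stand-in coroutine, not sglang code):

import asyncio

async def wait_for_zero() -> None:
    await asyncio.sleep(0)   # stand-in for waiting on the counter condition

async def main() -> None:
    wait_for_zero()          # bug: returns a coroutine, nothing runs
    await wait_for_zero()    # fix: actually suspends until completion

asyncio.run(main())          # the first call triggers a "never awaited" warning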
sglang/srt/weight_sync/tensor_bucket.py ADDED
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+    """Metadata for a tensor in a flattened bucket"""
+
+    name: str
+    shape: torch.Size
+    dtype: torch.dtype
+    start_idx: int
+    end_idx: int
+    numel: int
+
+
+class FlattenedTensorBucket:
+    """
+    A bucket that flattens multiple tensors into a single tensor for efficient processing
+    while preserving all metadata needed for reconstruction.
+    """
+
+    def __init__(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]] = None,
+        flattened_tensor: torch.Tensor = None,
+        metadata: List[FlattenedTensorMetadata] = None,
+    ):
+        """
+        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+        Args:
+            named_tensors: List of (name, tensor) tuples (for creating new bucket)
+            flattened_tensor: Pre-flattened tensor (for reconstruction)
+            metadata: Pre-computed metadata (for reconstruction)
+        """
+        if named_tensors is not None:
+            # Create bucket from named tensors
+            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+            self.flattened_tensor: torch.Tensor = None
+
+            if not named_tensors:
+                raise ValueError("Cannot create empty tensor bucket")
+
+            # Collect metadata and flatten tensors
+            current_idx = 0
+            flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+            for i, (name, tensor) in enumerate(named_tensors):
+                flattened = tensor.flatten()
+                flattened_tensors[i] = flattened
+
+                # Store metadata
+
+                numel = flattened.numel()
+                metadata_obj = FlattenedTensorMetadata(
+                    name=name,
+                    shape=tensor.shape,
+                    dtype=tensor.dtype,
+                    start_idx=current_idx,
+                    end_idx=current_idx + numel,
+                    numel=numel,
+                )
+                self.metadata[i] = metadata_obj
+                current_idx += numel
+
+            # Concatenate all flattened tensors
+            self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+        else:
+            # Initialize from pre-flattened data
+            if flattened_tensor is None or metadata is None:
+                raise ValueError(
+                    "Must provide either named_tensors or both flattened_tensor and metadata"
+                )
+            self.flattened_tensor = flattened_tensor
+            self.metadata = metadata
+
+    def get_flattened_tensor(self) -> torch.Tensor:
+        """Get the flattened tensor containing all bucket tensors"""
+        return self.flattened_tensor
+
+    def get_metadata(self) -> List[FlattenedTensorMetadata]:
+        """Get metadata for all tensors in the bucket"""
+        return self.metadata
+
+    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+        """
+        Reconstruct original tensors from flattened tensor with optimized performance.
+        Uses memory-efficient operations to minimize allocations and copies.
+        """
+        # preallocate the result list
+        reconstructed = [None] * len(self.metadata)
+
+        for i, meta in enumerate(self.metadata):
+            tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
+                meta.shape
+            )
+
+            # batch dtype conversion (if needed)
+            if tensor.dtype != meta.dtype:
+                tensor = tensor.to(meta.dtype)
+
+            reconstructed[i] = (meta.name, tensor)
+
+        return reconstructed
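A round-trip sketch for the new bucket (names and shapes arbitrary). Note that `reconstruct_tensors` returns views into the flattened buffer when no dtype conversion is needed, so the rebuild itself is allocation-free:

import torch

from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

named = [("w1", torch.randn(4, 8)), ("b1", torch.randn(8))]
bucket = FlattenedTensorBucket(named_tensors=named)

flat = bucket.get_flattened_tensor()
assert flat.numel() == 4 * 8 + 8   # one contiguous 1-D buffer

# Receiver side: rebuild from the flat tensor plus metadata alone.
rebuilt = FlattenedTensorBucket(
    flattened_tensor=flat, metadata=bucket.get_metadata()
)
for (name, orig), (name2, recon) in zip(named, rebuilt.reconstruct_tensors()):
    assert name == name2 and torch.equal(orig, recon)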