sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0

sglang/srt/models/siglip.py
@@ -0,0 +1,294 @@
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
+
+ from functools import partial
+ from typing import Optional, Type, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import SiglipVisionConfig
+
+ from sglang.srt.layers.activation import QuickGELU
+ from sglang.srt.layers.attention.vision import VisionAttention
+ from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.utils import add_prefix
+
+
+ # Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+ class SiglipVisionEmbeddings(nn.Module):
+
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = VocabParallelEmbedding(
+             self.num_positions, self.embed_dim
+         )
+         self.register_buffer(
+             "position_ids",
+             torch.arange(self.num_positions).expand((1, -1)),
+             persistent=False,
+         )
+
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         target_dtype = self.patch_embedding.weight.dtype
+         patch_embeds = self.patch_embedding(
+             pixel_values.to(dtype=target_dtype)
+         )  # shape = [*, width, grid, grid]
+         embeddings = patch_embeds.flatten(2).transpose(1, 2)
+         # interpolate_pos_encoding is never used in sglang
+         embeddings = embeddings + self.position_embedding(self.position_ids)
+
+         return embeddings
+
+
+ # Copied from sglang.srt.models.clip.CLIPMLP
+ class SiglipMLP(nn.Module):
+
+     def __init__(
+         self,
+         config,
+         act_layer: Type[nn.Module] = QuickGELU,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.fc1 = ColumnParallelLinear(
+             config.hidden_size,
+             config.intermediate_size,
+             quant_config=quant_config,
+             prefix=add_prefix("fc1", prefix),
+         )
+         self.act = act_layer()
+         self.fc2 = RowParallelLinear(
+             config.intermediate_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("fc2", prefix),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_parallel, _ = self.fc1(x)
+         x_parallel = self.act(x_parallel)
+         x, _ = self.fc2(x_parallel)
+         return x
+
+
+ # Copied from sglang.srt.models.clip.CLIPEncoderLayer
+ class SiglipEncoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         act_layer: Type[nn.Module] = QuickGELU,
+         norm_layer: Type[nn.Module] = None,
+         attn_implementation: Optional[str] = "sdpa",
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+         self.layer_norm1 = norm_layer(config.hidden_size)
+         self.layer_norm2 = norm_layer(config.hidden_size)
+         if attn_implementation == "sdpa":
+             qkv_backend = "sdpa"
+             softmax_in_single_precision = False
+         elif attn_implementation == "flash_attention_2":
+             qkv_backend = "triton_attn"
+             softmax_in_single_precision = False
+         elif attn_implementation == "eager":
+             qkv_backend = "sdpa"
+             softmax_in_single_precision = True
+         self.self_attn = VisionAttention(
+             embed_dim=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             projection_size=config.hidden_size,
+             use_qkv_parallel=True,
+             qkv_backend=qkv_backend,
+             softmax_in_single_precision=softmax_in_single_precision,
+             flatten_batch=True,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         self.mlp = SiglipMLP(
+             config,
+             act_layer=act_layer,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         causal_attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+
+         residual = hidden_states
+         hidden_states = self.layer_norm1(hidden_states)
+         # Siglip text model uses both `causal_attention_mask` and `attention_mask`
+         if attention_mask is not None and causal_attention_mask is not None:
+             attn_mask = attention_mask + causal_attention_mask
+         elif causal_attention_mask is not None:
+             attn_mask = causal_attention_mask
+         else:
+             attn_mask = attention_mask
+         hidden_states = self.self_attn(
+             hidden_states,
+             attention_mask=attn_mask,
+             # causal_attention_mask=causal_attention_mask,
+         )
+
+         hidden_states = residual + hidden_states
+         residual = hidden_states
+         hidden_states = self.layer_norm2(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ # Copied from sglang.srt.models.clip.CLIPEncoder
+ class SiglipEncoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self
+     attention layers. Each layer is a [`SiglipEncoderLayer`].
+
+     Args:
+         config: SiglipConfig
+     """
+
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+
+         num_hidden_layers = config.num_hidden_layers
+         norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+         self.layers = nn.ModuleList(
+             [
+                 SiglipEncoderLayer(
+                     config=config,
+                     norm_layer=norm_layer,
+                     attn_implementation="sdpa",
+                     quant_config=quant_config,
+                     prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                 )
+                 for layer_idx in range(num_hidden_layers)
+             ]
+         )
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         causal_attention_mask: torch.Tensor = None,
+         return_all_hidden_states: bool = False,
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         hidden_states_pool = [inputs_embeds]
+         hidden_states = inputs_embeds
+
+         for encoder_layer in self.layers:
+             hidden_states = encoder_layer(
+                 hidden_states, attention_mask, causal_attention_mask
+             )
+             if return_all_hidden_states:
+                 hidden_states_pool.append(hidden_states)
+         if return_all_hidden_states:
+             return hidden_states_pool
+         return hidden_states
+
+
+ # Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+ class SiglipVisionTransformer(nn.Module):
+
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+         embed_dim = config.hidden_size
+
+         self.embeddings = SiglipVisionEmbeddings(config)
+
+         self.encoder = SiglipEncoder(
+             config=config,
+             quant_config=quant_config,
+             prefix=add_prefix("encoder", prefix),
+         )
+
+         num_hidden_layers = config.num_hidden_layers
+         if len(self.encoder.layers) > config.num_hidden_layers:
+             raise ValueError(
+                 f"The original encoder only has {num_hidden_layers} "
+                 f"layers, but you requested {len(self.encoder.layers)} layers."
+             )
+
+         # VisionAttention in SiglipEncoderLayer is multihead attention
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+     @property
+     def device(self) -> torch.device:
+         return self.encoder.layers[0].layer_norm1.weight.device
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+     ) -> torch.Tensor:
+         hidden_states = self.embeddings(pixel_values.to(self.device))
+
+         return_all_hidden_states = False
+
+         last_hidden_state = self.encoder(
+             inputs_embeds=hidden_states,
+             return_all_hidden_states=return_all_hidden_states,
+         )
+
+         last_hidden_state = self.post_layernorm(last_hidden_state)
+
+         return last_hidden_state
+
+
+ # Copied from sglang.srt.models.clip.CLIPVisionModel
+ class SiglipVisionModel(nn.Module):
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.vision_model = SiglipVisionTransformer(
+             config, quant_config, prefix=add_prefix("vision_model", prefix)
+         )
+
+     @property
+     def device(self) -> torch.device:
+         return self.vision_model.device
+
+     def forward(self, pixel_values: torch.Tensor):
+         return self.vision_model(pixel_values)
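
The patch embedding above fixes the sequence length the encoder sees: an image becomes (image_size // patch_size) ** 2 patch tokens of width hidden_size, and SiglipVisionModel.forward returns the post-layernorm hidden states at that length. A minimal shape sketch; the config values below are illustrative (a so400m-style configuration) and are not taken from this diff:

from transformers import SiglipVisionConfig

# Illustrative config values; real checkpoints ship their own.
cfg = SiglipVisionConfig(hidden_size=1152, image_size=384, patch_size=14, num_channels=3)

num_patches = (cfg.image_size // cfg.patch_size) ** 2  # 27 * 27 = 729 patch tokens
# SiglipVisionModel(cfg).forward(pixel_values) expects pixel_values of shape
# [batch, num_channels, image_size, image_size] and returns hidden states of
# shape [batch, num_patches, hidden_size], i.e. [batch, 729, 1152] here.
print(num_patches, cfg.hidden_size)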

sglang/srt/openai_api/adapter.py
@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
      get_conv_template_by_model_path,
      register_conv_template,
  )
- from sglang.srt.function_call_parser import FunctionCallParser
+ from sglang.srt.function_call.function_call_parser import FunctionCallParser
  from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
  from sglang.srt.openai_api.protocol import (
      BatchRequest,
@@ -970,7 +970,7 @@ def v1_chat_generate_request(
          # - image_data: None or a list of image strings (URLs or base64 strings).
          # - audio_data: None or a list of audio strings (URLs).
          # None skips any image processing in GenerateReqInput.
-         strict_tag = None
+         tool_call_constraint = None
          prompt = ""
          prompt_ids = []
          if not isinstance(request.messages, str):
@@ -989,7 +989,9 @@ def v1_chat_generate_request(
 
              tool_call_parser = tokenizer_manager.server_args.tool_call_parser
              parser = FunctionCallParser(request.tools, tool_call_parser)
-             strict_tag = parser.get_structure_tag()
+             tool_call_constraint = parser.get_structure_constraint(
+                 request.tool_choice
+             )
 
          if chat_template_name is None:
              openai_compatible_messages = []
@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
                  request.response_format.model_dump(by_alias=True)
              )
 
-         if strict_tag is not None:
-             if (
-                 sampling_params.get("regex")
-                 or sampling_params.get("ebnf")
-                 or sampling_params.get("structural_tag")
-                 or sampling_params.get("json_schema")
-             ):
-                 logger.warning(
-                     "Constrained decoding is not compatible with tool calls."
+         # Check if there are already existing output constraints
+         has_existing_constraints = (
+             sampling_params.get("regex")
+             or sampling_params.get("ebnf")
+             or sampling_params.get("structural_tag")
+             or sampling_params.get("json_schema")
+         )
+
+         if tool_call_constraint and has_existing_constraints:
+             logger.warning("Constrained decoding is not compatible with tool calls.")
+         elif tool_call_constraint:
+             constraint_type, constraint_value = tool_call_constraint
+             if constraint_type == "structural_tag":
+                 sampling_params[constraint_type] = convert_json_schema_to_str(
+                     constraint_value.model_dump(by_alias=True)
                  )
              else:
-                 sampling_params["structural_tag"] = convert_json_schema_to_str(
-                     strict_tag.model_dump(by_alias=True)
-                 )
+                 sampling_params[constraint_type] = constraint_value
 
          sampling_params_list.append(sampling_params)
 
@@ -1193,6 +1199,7 @@ def v1_chat_generate_request(
          top_logprobs_nums = top_logprobs_nums[0]
          modalities_list = modalities_list[0]
          lora_paths = lora_paths[0]
+         request_ids = request_ids[0]
      else:
          if tokenizer_manager.model_config.is_multimodal:
              # processor will need text input
@@ -1429,7 +1436,9 @@ async def v1_chat_completions(
          return create_error_response("Invalid request body, error: ", str(e))
      all_requests = [ChatCompletionRequest(**request_json)]
      created = int(time.time())
-     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
+     adapted_request, request = v1_chat_generate_request(
+         all_requests, tokenizer_manager, request_ids=[all_requests[0].rid]
+     )
 
      if adapted_request.stream:
          parser_dict = {}
@@ -1812,6 +1821,7 @@ def v1_embedding_request(all_requests, tokenizer_manager):
              prompt_kwargs = {"text": generate_prompts, "image_data": images}
          else:
              prompt_kwargs = {"input_ids": prompt}
+         request_ids = all_requests[0].rid
      else:
          if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
              prompt_kwargs = {"text": prompts}
@@ -1824,8 +1834,10 @@ def v1_embedding_request(all_requests, tokenizer_manager):
              )
          else:
              prompt_kwargs = {"input_ids": prompts}
+         request_ids = [req.rid for req in all_requests]
 
      adapted_request = EmbeddingReqInput(
+         rid=request_ids,
          **prompt_kwargs,
      )
 
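
The reworked block above treats the parser output as a (constraint_type, constraint_value) pair: only structural_tag constraints are serialized via convert_json_schema_to_str, other types such as an EBNF grammar are installed as-is, and any pre-existing output constraint wins with a warning. A standalone sketch of that branching; the helper name and the sample constraint are illustrative, not part of the diff:

import json

def apply_tool_call_constraint(sampling_params: dict, tool_call_constraint):
    # Mirrors the branching in v1_chat_generate_request above.
    has_existing_constraints = any(
        sampling_params.get(k)
        for k in ("regex", "ebnf", "structural_tag", "json_schema")
    )
    if tool_call_constraint and has_existing_constraints:
        print("warning: constrained decoding is not compatible with tool calls")
    elif tool_call_constraint:
        constraint_type, constraint_value = tool_call_constraint
        if constraint_type == "structural_tag":
            # The real code serializes a pydantic model via convert_json_schema_to_str.
            sampling_params[constraint_type] = json.dumps(constraint_value)
        else:
            sampling_params[constraint_type] = constraint_value
    return sampling_params

# An EBNF constraint passes through unchanged:
print(apply_tool_call_constraint({}, ("ebnf", 'root ::= "a"')))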

sglang/srt/openai_api/protocol.py
@@ -392,6 +392,9 @@ class ChatCompletionRequest(BaseModel):
      stream_reasoning: bool = True
      chat_template_kwargs: Optional[Dict] = None
 
+     # The request id.
+     rid: Optional[str] = None
+
      # For PD disaggregation
      bootstrap_host: Optional[str] = None
      bootstrap_port: Optional[int] = None
@@ -466,6 +469,9 @@ class EmbeddingRequest(BaseModel):
      dimensions: int = None
      user: Optional[str] = None
 
+     # The request id.
+     rid: Optional[str] = None
+
 
  class EmbeddingObject(BaseModel):
      embedding: List[float]
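
The new optional rid field lets a client supply its own request id on chat-completion and embedding requests instead of having the server generate one; the adapter changes above thread it through as request_ids. A hedged example against a locally running OpenAI-compatible server; URL, port, and model name are placeholders:

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # default sglang port; adjust as needed
    json={
        "model": "placeholder-model",
        "messages": [{"role": "user", "content": "hello"}],
        "rid": "my-trace-id-123",  # the new optional request id field
    },
)
print(resp.status_code, resp.json())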

sglang/srt/operations.py
@@ -0,0 +1,154 @@
+ import os
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+
+ import torch
+
+ _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
+
+ if _ENABLE_PROFILE:
+     import nvtx
+
+
+ def execute_operations(inputs, operations):
+     stages = _convert_operations_to_stages(decorate_operations(operations))
+     executor = _StageExecutor("primary", stages, inputs=inputs)
+     for _ in range(executor.num_stages):
+         executor.next()
+     assert executor.done
+     return executor.output
+
+
+ class YieldOperation:
+     pass
+
+
+ @dataclass
+ class ExecutionOperation:
+     debug_name: str
+     fn: Callable
+
+
+ Operation = Union[YieldOperation, ExecutionOperation, Callable]
+ Stage = List[ExecutionOperation]
+
+
+ class _StageExecutor:
+     def __init__(self, debug_name: str, stages: List[Stage], inputs):
+         self._debug_name = debug_name
+         self._stages = stages
+         self._index = 0
+         self._stage_state = _StateDict()
+         self._stage_output = inputs
+
+     def next(self):
+         assert not self.done
+
+         stage = self._stages[self._index]
+
+         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
+             for op in stage:
+                 with _annotate_region(debug_name=op.debug_name):
+                     self._stage_output = op.fn(
+                         state=self._stage_state,
+                         **(
+                             self._stage_output if self._stage_output is not None else {}
+                         ),
+                     )
+
+         self._index += 1
+
+     @property
+     def output(self):
+         assert self.done
+         return self._stage_output
+
+     @property
+     def done(self):
+         return self._index >= self.num_stages
+
+     @property
+     def num_stages(self):
+         return len(self._stages)
+
+
+ @contextmanager
+ def _annotate_region(debug_name):
+     if _ENABLE_PROFILE:
+         with torch.autograd.profiler.record_function(debug_name):
+             with nvtx.annotate(debug_name):
+                 yield
+     else:
+         yield
+
+
+ class _StateDict:
+     def __init__(self):
+         self._data = {}
+
+     def __setattr__(self, key, value):
+         if key == "_data":
+             super().__setattr__(key, value)
+             return
+         assert (
+             key not in self._data
+         ), f"`{key}` already exists, are you sure you want to override it?"
+         self._data[key] = value
+
+     def __getattr__(self, item):
+         return self._data[item]
+
+     def __delattr__(self, item):
+         del self._data[item]
+
+     def pop(self, item):
+         return self._data.pop(item)
+
+     def update(self, values: Dict[str, Any]):
+         for k, v in values.items():
+             setattr(self, k, v)
+
+     def clear(self, expect_keys: Sequence[str]):
+         if set(self._data.keys()) != set(expect_keys):
+             raise Exception(
+                 f"Unexpected keys when clearing. This may indicate you do not release memory early enough but leave it to here. {list(self._data.keys())=} {expect_keys=}"
+             )
+
+         self._data.clear()
+
+
+ def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
+     operation_chunks = list(
+         _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
+     )
+     assert all(len(chunk) > 0 for chunk in operation_chunks)
+     return operation_chunks
+
+
+ def _chunk_by_separator(
+     items: List[Any], is_separator: Callable[[Any], bool]
+ ) -> Generator[List[Any], None, None]:
+     pending_items = []
+     for item in items:
+         if is_separator(item):
+             yield pending_items
+             pending_items = []
+         else:
+             pending_items.append(item)
+     if len(pending_items) > 0:
+         yield pending_items
+
+
+ def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
+     return [_decorate_operation(op, debug_name_prefix) for op in operations]
+
+
+ def _decorate_operation(operation: Operation, debug_name_prefix: str):
+     if isinstance(operation, YieldOperation):
+         return operation
+     return ExecutionOperation(
+         debug_name=debug_name_prefix
+         + getattr(operation, "__name__", "unknown").replace("op_", ""),
+         fn=operation,
+     )
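
Taken together, execute_operations splits the operation list into stages at each YieldOperation, threads one shared _StateDict through every call, and feeds each operation's returned dict into the next operation as keyword arguments. A small self-contained sketch of that contract; the op functions are made up for illustration:

from sglang.srt.operations import YieldOperation, execute_operations

def op_double(state, x):
    state.doubled = x * 2          # stash an intermediate on the shared state
    return {"x": state.doubled}    # becomes the kwargs of the next operation

def op_add_one(state, x):
    state.pop("doubled")           # release the intermediate once it is consumed
    return {"x": x + 1}

# The YieldOperation splits this list into two stages: [op_double] and [op_add_one].
result = execute_operations({"x": 3}, [op_double, YieldOperation(), op_add_one])
print(result)  # {'x': 7}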

sglang/srt/operations_strategy.py
@@ -0,0 +1,31 @@
+ import torch
+
+
+ def compute_layer_operations(
+     layer: torch.nn.Module,
+ ):
+     if not layer.is_layer_sparse:
+         return [
+             layer.op_comm_prepare_attn,
+             layer.op_attn,
+             layer.op_comm_prepare_mlp,
+             layer.op_mlp,
+             layer.op_comm_postprocess_layer,
+         ]
+
+     # Will add TBO operation orders here
+     return [
+         layer.op_comm_prepare_attn,
+         layer.op_attn,
+         layer.op_comm_prepare_mlp,
+         layer.mlp.op_gate,
+         layer.mlp.op_shared_experts,
+         layer.mlp.op_select_experts,
+         layer.mlp.op_dispatch_a,
+         layer.mlp.op_dispatch_b,
+         layer.mlp.op_experts,
+         layer.mlp.op_combine_a,
+         layer.mlp.op_combine_b,
+         layer.mlp.op_output,
+         layer.op_comm_postprocess_layer,
+     ]
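
Presumably these per-layer operation lists are what execute_operations consumes during a forward pass: the dense path is a single stage today, while the sparse (MoE) path exposes dispatch/combine sub-steps so overlapping schedules (the TBO comment above) can later insert yields between them. A hypothetical pairing, not shown in this diff:

from sglang.srt.operations import execute_operations
from sglang.srt.operations_strategy import compute_layer_operations

# Hypothetical wiring: run one decoder layer's operation list over a dict of forward inputs.
def run_layer(layer, forward_inputs: dict):
    return execute_operations(forward_inputs, compute_layer_operations(layer))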