sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
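
The change that dominates the hunks below is a single rename threaded through every model file: the per-step batch object passed into each model's forward() is now ForwardBatch (from sglang.srt.model_executor.forward_batch_info) instead of InputMetadata. The following is a minimal, hedged sketch of the new calling convention only; ToyDecoderLayer and its sub-modules are hypothetical stand-ins, not sglang classes, and the import fallback exists solely so the snippet stays runnable without sglang 0.3.3 installed.

# Sketch only (not sglang source): illustrates the 0.3.1.post3 -> 0.3.3 rename of the
# per-step batch object from InputMetadata to ForwardBatch, which every model's
# forward() below now threads down to its attention layers.
from typing import Optional

import torch
from torch import nn

try:
    # New name in sglang 0.3.3 (was InputMetadata in 0.3.1.post3).
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
except ImportError:
    class ForwardBatch:  # placeholder so the sketch runs without sglang installed
        pass


class ToyDecoderLayer(nn.Module):
    """Hypothetical layer mirroring the signature pattern in the hunks below."""

    def __init__(self, hidden_size: int = 16) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,  # was: input_metadata: InputMetadata
        residual: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Real sglang layers hand forward_batch to RadixAttention, e.g.
        #   attn_output = self.attn(q, k, v, forward_batch)
        del positions, forward_batch  # unused in this toy layer
        out = self.norm(self.proj(hidden_states))
        return out + (residual if residual is not None else hidden_states)


if __name__ == "__main__":
    layer = ToyDecoderLayer()
    x = torch.randn(4, 16)
    # None is passed purely for the demo; a real server builds a ForwardBatch per step.
    print(layer(torch.arange(4), x, None).shape)  # torch.Size([4, 16])
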
sglang/srt/models/minicpm3.py CHANGED
@@ -42,11 +42,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
- from sglang.srt.utils import is_hip
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.utils import is_flashinfer_available

- # ROCm: flashinfer available later
- if not is_hip():
+ if is_flashinfer_available():
  from flashinfer import bmm_fp8


@@ -193,7 +192,7 @@ class MiniCPM3Attention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  if self.q_lora_rank is not None:
  q = self.q_a_proj(hidden_states)[0]
@@ -230,7 +229,7 @@ class MiniCPM3Attention(nn.Module):
  v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
  -1, self.num_local_heads * 128
  )
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, 128)[
  ..., : self.v_head_dim
  ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -341,7 +340,7 @@ class MiniCPM3AttentionMLA(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  q_len = hidden_states.shape[0]
  q_input = hidden_states.new_empty(
@@ -383,7 +382,7 @@ class MiniCPM3AttentionMLA(nn.Module):
  q_input[..., self.kv_lora_rank :] = q_pe
  k_input[..., self.kv_lora_rank :] = k_pe

- attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+ attn_output = self.attn(q_input, k_input, v_input, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

  if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -472,7 +471,7 @@ class MiniCPM3DecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  # Self Attention
@@ -481,7 +480,7 @@ class MiniCPM3DecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )
  hidden_states = residual + hidden_states * (
  self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -528,7 +527,7 @@ class MiniCPM3Model(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -542,7 +541,7 @@ class MiniCPM3Model(nn.Module):
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  residual,
  )
  hidden_states = self.norm(hidden_states)
@@ -581,19 +580,19 @@ class MiniCPM3ForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is not None:
  input_embeds = input_embeds * self.config.scale_emb
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  hidden_states = hidden_states / self.scale_width
  if self.config.tie_word_embeddings:
  lm_head_weight = self.model.embed_tokens.weight
  else:
  lm_head_weight = self.lm_head.weight
  return self.logits_processor(
- input_ids, hidden_states, lm_head_weight, input_metadata
+ input_ids, hidden_states, lm_head_weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mixtral.py CHANGED
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class MixtralMoE(nn.Module):
@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
  # Self Attention
@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
  for i in range(len(self.layers)):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, input_metadata, residual
+ positions, hidden_states, forward_batch, residual
  )
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states
@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class MixtralMLP(nn.Module):
@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
  # Self Attention
@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
  for i in range(len(self.layers)):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, input_metadata, residual
+ positions, hidden_states, forward_batch, residual
  )
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states
@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmoe.py CHANGED
@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class OlmoeMoE(nn.Module):
@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
  # Self Attention
@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
  for i in range(len(self.layers)):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, input_metadata, residual
+ positions, hidden_states, forward_batch, residual
  )
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states
@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen.py CHANGED
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class QWenMLP(nn.Module):
@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.c_attn(hidden_states)
  q, k, v = qkv.chunk(chunks=3, dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.c_proj(attn_output)
  return output

@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  # Self Attention
  residual = hidden_states
@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
  hidden_states = self.attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )
  hidden_states = residual + hidden_states

@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  hidden_states = self.wte(input_ids)
  for i in range(len(self.h)):
@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
  hidden_states = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  )
  hidden_states = self.ln_f(hidden_states)
  return hidden_states
@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ):
- hidden_states = self.transformer(input_ids, positions, input_metadata)
+ hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen2.py CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch

  Qwen2Config = None

@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  # Self Attention
@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  residual,
  )
  hidden_states, _ = self.norm(hidden_states, residual)
@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  get_embedding: bool = False,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )
  else:
- return self.pooler(hidden_states, input_metadata)
+ return self.pooler(hidden_states, forward_batch)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/qwen2_moe.py CHANGED
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class Qwen2MoeMLP(nn.Module):
@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
  # Self Attention
@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module):
  for i in range(len(self.layers)):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, input_metadata, residual
+ positions, hidden_states, forward_batch, residual
  )
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states
@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/stablelm.py CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class StablelmMLP(nn.Module):
@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  # Self Attention
  residual = hidden_states
@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )
  hidden_states = residual + hidden_states

@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module):
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  )
  hidden_states = self.norm(hidden_states)
  return hidden_states
@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):