sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class DeepseekMLP(nn.Module):
@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
  # Self Attention
@@ -315,7 +315,7 @@
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  hidden_states = self.embed_tokens(input_ids)
  residual = None
  for i in range(len(self.layers)):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, input_metadata, residual
+ positions, hidden_states, forward_batch, residual
  )
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states
@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata)
+ hidden_states = self.model(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/deepseek_v2.py CHANGED
@@ -46,11 +46,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
- from sglang.srt.utils import is_hip
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.utils import is_flashinfer_available

- # ROCm: flashinfer available later
- if not is_hip():
+ if is_flashinfer_available():
  from flashinfer import bmm_fp8


@@ -281,7 +280,7 @@ class DeepseekV2Attention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  if self.q_lora_rank is not None:
  q = self.q_a_proj(hidden_states)[0]
@@ -314,7 +313,7 @@
  v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
  -1, self.num_local_heads * 256
  )
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, 256)[
  ..., : self.v_head_dim
  ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -433,7 +432,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  q_len = hidden_states.shape[0]
  q_input = hidden_states.new_empty(
@@ -471,7 +470,7 @@
  q_input[..., self.kv_lora_rank :] = q_pe
  k_input[..., self.kv_lora_rank :] = k_pe

- attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+ attn_output = self.attn(q_input, k_input, v_input, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

  if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -567,7 +566,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> torch.Tensor:
  # Self Attention
@@ -579,7 +578,7 @@
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -623,14 +622,14 @@ class DeepseekV2Model(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  hidden_states = self.embed_tokens(input_ids)
  residual = None
  for i in range(len(self.layers)):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, input_metadata, residual
+ positions, hidden_states, forward_batch, residual
  )
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states
@@ -658,11 +657,11 @@ class DeepseekV2ForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata)
+ hidden_states = self.model(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/exaone.py CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class ExaoneGatedMLP(nn.Module):
@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.out_proj(attn_output)
  return output

@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  # Self Attention
@@ -232,7 +232,7 @@
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -283,7 +283,7 @@
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  residual,
  )
  hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> LogitsProcessorOutput:
  hidden_states = self.transformer(
- input_ids, positions, input_metadata, input_embeds
+ input_ids, positions, forward_batch, input_embeds
  )
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class GemmaMLP(nn.Module):
@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  # Self Attention
@@ -192,7 +192,7 @@
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -243,7 +243,7 @@
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  residual,
  )
  hidden_states, _ = self.norm(hidden_states, residual)
@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+ input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma2.py CHANGED
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -163,24 +163,24 @@ class Gemma2Attention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_idx,
+ logit_cap=self.config.attn_logit_softcapping,
  sliding_window_size=(
  get_attention_sliding_window_size(config)
  if use_sliding_window
  else None
  ),
- logit_cap=self.config.attn_logit_softcapping,
  )

  def forward(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  if residual is None:
@@ -241,7 +241,7 @@
  hidden_states = self.self_attn(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )
  hidden_states = self.post_attention_layernorm(hidden_states)

@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -302,7 +302,7 @@
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  residual,
  )
  hidden_states, _ = self.norm(hidden_states, residual)
@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+ input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
  )

  def get_attention_sliding_window_size(self):
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class GPTBigCodeAttention(nn.Module):
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
  def forward(
  self,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.c_attn(hidden_states)
  q, k, v = qkv.split(
@@ -101,7 +101,7 @@
  ],
  dim=-1,
  )
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  attn_output, _ = self.c_proj(attn_output)
  return attn_output

@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
  def forward(
  self,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  residual = hidden_states
  hidden_states = self.ln_1(hidden_states)
  attn_output = self.attn(
- hidden_states=hidden_states, input_metadata=input_metadata
+ hidden_states=hidden_states, forward_batch=forward_batch
  )
  # residual connection
  hidden_states = attn_output + residual
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
  self,
  input_ids: torch.Tensor,
  position_ids: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  inputs_embeds = self.wte(input_ids)
  position_embeds = self.wpe(position_ids)
@@ -222,7 +222,7 @@

  for i in range(len(self.h)):
  layer = self.h[i]
- hidden_states = layer(hidden_states, input_metadata)
+ hidden_states = layer(hidden_states, forward_batch)

  hidden_states = self.ln_f(hidden_states)
  return hidden_states
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
- hidden_states = self.transformer(input_ids, positions, input_metadata)
+ hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/grok.py CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class Grok1MoE(nn.Module):
@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
  return output

@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  # Self Attention
  hidden_states = (
@@ -227,7 +227,7 @@
  self.self_attn(
  positions=positions,
  hidden_states=self.pre_attn_norm(hidden_states),
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )
  )
  + hidden_states
@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -278,7 +278,7 @@
  hidden_states = input_embeds

  for i in range(len(self.layers)):
- hidden_states = self.layers[i](positions, hidden_states, input_metadata)
+ hidden_states = self.layers[i](positions, hidden_states, forward_batch)
  hidden_states = self.norm(hidden_states)
  hidden_states.mul_(self.config.output_multiplier_scale)
  return hidden_states
@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, input_metadata
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class InternLM2MLP(nn.Module):
@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.wqkv(hidden_states)
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, input_metadata)
+ attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.wo(attn_output)
  return output

@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  # Self Attention
@@ -194,7 +194,7 @@
  hidden_states = self.attention(
  positions=positions,
  hidden_states=hidden_states,
- input_metadata=input_metadata,
+ forward_batch=forward_batch,
  )

  # Fully Connected
@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  if input_embeds is None:
@@ -242,7 +242,7 @@
  hidden_states, residual = layer(
  positions,
  hidden_states,
- input_metadata,
+ forward_batch,
  residual,
  )
  hidden_states, _ = self.norm(hidden_states, residual)
@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
- input_metadata: InputMetadata,
+ forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.output.weight, input_metadata
+ input_ids, hidden_states, self.output.weight, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):