sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
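
Nearly every model hunk below follows from a single refactor: the per-call InputMetadata object is replaced by a ForwardBatch that the scheduler builds once and threads through each model's forward(). A minimal sketch of the resulting call shape; run_forward is a hypothetical helper, and model and forward_batch are assumed to be supplied by the model runner:

    # Sketch only, not an sglang API: shows the 0.3.3 calling convention.
    def run_forward(model, forward_batch):
        return model.forward(
            forward_batch.input_ids,  # token ids for the batch
            forward_batch.positions,  # per-token positions
            forward_batch,            # replaces the old InputMetadata
        )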
sglang/srt/model_executor/model_runner.py CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type

 import torch
 import torch.nn as nn
@@ -38,20 +38,23 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
-from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
+from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    enable_show_time_cost,
     get_available_gpu_memory,
     is_generation_model,
     is_multimodal_model,
@@ -87,6 +90,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )

+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -94,6 +98,19 @@ class ModelRunner:
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.server_args.attention_backend = "triton"

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
+        # Global vars
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+        if server_args.disable_disk_cache:
+            disable_cache()
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -104,14 +121,6 @@ class ModelRunner:
             }
         )

-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
@@ -135,8 +144,8 @@ class ModelRunner:
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

-        if self.server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
+        if self.server_args.dist_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
@@ -222,6 +231,7 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
@@ -399,9 +409,11 @@ class ModelRunner:
             4096,
         )

+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -413,6 +425,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -421,6 +434,7 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         logger.info(
             f"Memory pool end. "
@@ -445,6 +459,10 @@ class ModelRunner:
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
+            assert not self.has_cross_attention, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
             self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
@@ -467,73 +485,59 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

-    @torch.inference_mode()
-    def forward_decode(self, batch: ScheduleBatch):
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch)
-
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
-            return self.cuda_graph_runner.replay(batch)
-
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+    def forward_decode(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
+            forward_batch.batch_size
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)

         return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
+            forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    @torch.inference_mode()
-    def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
-
+    def forward_extend(self, forward_batch: ForwardBatch):
         if self.is_generation:
             return self.model.forward(
-                batch.input_ids, input_metadata.positions, input_metadata
+                forward_batch.input_ids, forward_batch.positions, forward_batch
             )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-                batch.input_ids,
-                input_metadata.positions,
-                input_metadata,
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
                 get_embedding=True,
             )

-    @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            input_metadata.pixel_values,
-            input_metadata.image_sizes,
-            input_metadata.image_offsets,
-        )
+    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(forward_batch)
+        elif forward_batch.forward_mode.is_extend():
+            return self.forward_extend(forward_batch)
+        else:
+            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

-    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
-        assert batch.forward_mode is not None
+    def sample(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        sampling_info = forward_batch.sampling_info
+        sampling_info.update_regex_vocab_mask()
+        sampling_info.update_penalties()
+        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)

-        if self.is_multimodal_model and batch.forward_mode.is_extend():
-            return self.forward_extend_multi_modal(batch)
-        elif batch.forward_mode.is_decode():
-            return self.forward_decode(batch)
-        elif batch.forward_mode.is_extend():
-            return self.forward_extend(batch)
-        else:
-            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, sampling_info)
+        return next_token_ids

-    def _apply_logits_bias(
-        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
-    ):
+    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Apply logit_bias
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         # min-token, presence, frequency
         if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
+            logits.add_(sampling_info.linear_penalties)

         # repetition
         if sampling_info.scaling_penalties is not None:
@@ -549,20 +553,6 @@ class ModelRunner:

         return logits

-    def sample(
-        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
-    ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-        batch.sampling_info.update_regex_vocab_mask(batch)
-        batch.sampling_info.update_penalties()
-        logits = self._apply_logits_bias(
-            logits_output.next_token_logits, batch.sampling_info
-        )
-
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, batch.sampling_info)
-        return next_token_ids
-

 @lru_cache()
 def import_model_classes():
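
The two hunks above collapse the old per-mode entry points into a single ModelRunner.forward() that dispatches on forward_mode, and move sampling into a sample() step whose CPU-heavy bookkeeping (regex vocab masks, penalties) is overlapped with the GPU forward pass. A rough sketch of how the pair is driven; runner and forward_batch are placeholders assumed to be built by the scheduler:

    # Sketch only; runner (a ModelRunner) and forward_batch (a ForwardBatch)
    # are assumed to exist already.
    logits_output = runner.forward(forward_batch)  # decode or extend, chosen by forward_mode
    next_token_ids = runner.sample(logits_output, forward_batch)  # mask/penalty prep, then sampling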
@@ -571,17 +561,25 @@ def import_model_classes():
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
-            module = importlib.import_module(name)
+            try:
+                module = importlib.import_module(name)
+            except Exception as e:
+                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+                continue
             if hasattr(module, "EntryClass"):
                 entry = module.EntryClass
                 if isinstance(
                     entry, list
                 ):  # To support multiple model classes in one module
                     for tmp in entry:
-                        assert tmp.__name__ not in model_arch_name_to_cls
+                        assert (
+                            tmp.__name__ not in model_arch_name_to_cls
+                        ), f"Duplicated model implementation for {tmp.__name__}"
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
-                    assert entry.__name__ not in model_arch_name_to_cls
+                    assert (
+                        entry.__name__ not in model_arch_name_to_cls
+                    ), f"Duplicated model implementation for {entry.__name__}"
                     model_arch_name_to_cls[entry.__name__] = entry

     return model_arch_name_to_cls
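
The hunk above makes model discovery tolerate broken modules and attaches a message to the duplicate-registration asserts. A self-contained miniature of the same EntryClass scan; scan_entry_classes is an illustrative stand-in for import_model_classes(), and the package name is supplied by the caller rather than hard-coded to sglang.srt.models:

    # Miniature of the EntryClass registry pattern; not the sglang function itself.
    import importlib
    import logging
    import pkgutil

    logger = logging.getLogger(__name__)

    def scan_entry_classes(package_name: str) -> dict:
        registry = {}
        package = importlib.import_module(package_name)
        for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
            if ispkg:
                continue
            try:
                module = importlib.import_module(name)
            except Exception as e:
                # Mirrors the new behavior: a module that fails to import is
                # skipped with a warning instead of aborting the scan.
                logger.warning(f"Ignore import error when loading {name}. {e}")
                continue
            entry = getattr(module, "EntryClass", None)
            if entry is None:
                continue
            for cls in entry if isinstance(entry, list) else [entry]:
                assert cls.__name__ not in registry, f"Duplicated model implementation for {cls.__name__}"
                registry[cls.__name__] = cls
        return registry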
sglang/srt/models/baichuan.py CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/chatglm.py CHANGED
@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 LoraConfig = None

@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
             q,
             k,
             v,
-            input_metadata,
+            forward_batch,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # hidden_states: [num_tokens, h]
         # Layer norm at the beginning of the transformer layer.
@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
         attention_output = self.self_attention(
             hidden_states=layernorm_output,
             position_ids=position_ids,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Residual connection.
@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         for i in range(self.num_layers):
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
-                input_metadata=input_metadata,
+                forward_batch=forward_batch,
             )
         # Final layer norm.
         if self.post_layer_norm:
@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)

@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
         hidden_states = self.encoder(
             hidden_states=inputs_embeds,
             position_ids=position_ids,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         return hidden_states

@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/commandr.py CHANGED
@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs


@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
         hidden_states_attention = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states_mlp = self.mlp(hidden_states)
         # Add everything together
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
             positions,
-            input_metadata,
+            forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/dbrx.py CHANGED
@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs


@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         hidden_states, _ = self.out_proj(attn_output)
         return hidden_states

@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + x
         residual = hidden_states
@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states, residual = self.norm_attn_norm(
             position_ids=position_ids,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = self.ffn(hidden_states)
         hidden_states = hidden_states + residual
@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
             hidden_states = input_embeds
         for i in range(len(self.blocks)):
             block = self.blocks[i]
-            hidden_states = block(position_ids, hidden_states, input_metadata)
+            hidden_states = block(position_ids, hidden_states, forward_batch)
         hidden_states = self.norm_f(hidden_states)
         return hidden_states

@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
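
Every model file shown above (baichuan, chatglm, commandr, dbrx) gets the same mechanical treatment, and the remaining model files in the list follow suit: import ForwardBatch instead of InputMetadata, rename the parameter, and pass it through to the attention layer. A sketch of the pattern on a made-up module; ToyAttention and its projections are hypothetical, and attn stands in for a RadixAttention instance:

    # Sketch of the uniform migration; ToyAttention is not a shipped model.
    import torch
    import torch.nn as nn

    from sglang.srt.model_executor.forward_batch_info import ForwardBatch  # was: InputMetadata

    class ToyAttention(nn.Module):
        def __init__(self, hidden_size: int, attn):
            super().__init__()
            self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
            self.o_proj = nn.Linear(hidden_size, hidden_size)
            self.attn = attn  # e.g. a RadixAttention instance in the real models

        # was: forward(self, positions, hidden_states, input_metadata: InputMetadata)
        def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            forward_batch: ForwardBatch,
        ) -> torch.Tensor:
            qkv = self.qkv_proj(hidden_states)
            q, k, v = qkv.chunk(chunks=3, dim=-1)
            attn_output = self.attn(q, k, v, forward_batch)  # forward_batch threaded through
            return self.o_proj(attn_output)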