sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
+    DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
@@ -81,10 +84,11 @@ class ModelRunner:
         # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
+        self.device = server_args.device
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.nccl_port = nccl_port
+        self.dist_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
@@ -95,9 +99,23 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-            logger.info("MLA optimization is tunred on. Use triton backend.")
+            logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            self.server_args.attention_backend = "triton"
+            self.server_args.disable_cuda_graph = True
+            if self.server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(
+                self.server_args.ds_heavy_channel_type
+            )
+
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -118,6 +136,8 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
+                "disable_nan_detection": server_args.disable_nan_detection,
             }
         )
 
@@ -132,39 +152,51 @@ class ModelRunner:
             server_args.max_running_requests,
             server_args.max_total_tokens,
         )
-        self.init_cublas()
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.cuda_graph_runner = None
+            self.init_attention_backend()
 
     def init_torch_distributed(self):
+        logger.info("Init torch distributed begin.")
         # Init torch distributed
-        torch.cuda.set_device(self.gpu_id)
-        logger.info("Init nccl begin.")
+        if self.device == "cuda":
+            torch.cuda.set_device(self.gpu_id)
+            backend = "nccl"
+        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+        # Need to use xccl for xpu backend in the future
+        elif self.device == "xpu":
+            torch.xpu.set_device(self.gpu_id)
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
-
         if self.server_args.dist_init_addr:
-            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
+            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
-            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
-            backend="nccl",
+            backend=backend,
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method,
+            distributed_init_method=dist_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         min_per_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
 
         # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
         # so we disable padding in cuda graph.
-        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+        if self.device == "cuda" and not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        ):
             self.server_args.disable_cuda_graph_padding = True
             logger.info(
                 "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
@@ -172,7 +204,7 @@ class ModelRunner:
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
@@ -182,23 +214,22 @@ class ModelRunner:
 
     def load_model(self):
         logger.info(
-            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
         # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
-
-        if torch.cuda.get_device_capability()[0] < 8:
-            logger.info(
-                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-            )
-            self.server_args.dtype = "float16"
-            if torch.cuda.get_device_capability()[1] < 5:
-                raise RuntimeError("SGLang only supports sm75 and above.")
+        if self.device == "cuda":
+            if torch.cuda.get_device_capability()[0] < 8:
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
+                self.server_args.dtype = "float16"
+                if torch.cuda.get_device_capability()[1] < 5:
+                    raise RuntimeError("SGLang only supports sm75 and above.")
 
         # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
-        self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
         self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
@@ -220,7 +251,7 @@ class ModelRunner:
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
-            device_config=self.device_config,
+            device_config=DeviceConfig(self.device),
             parallel_config=None,
             scheduler_config=None,
             lora_config=None,
@@ -240,7 +271,7 @@ class ModelRunner:
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
     def update_weights(self, model_path: str, load_format: str):
@@ -254,10 +285,10 @@ class ModelRunner:
 
         logger.info(
             f"Update weights begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-        target_device = torch.device(self.device_config.device)
+        target_device = torch.device(self.device)
 
         try:
             # TODO: Use a better method to check this
@@ -343,7 +374,7 @@ class ModelRunner:
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -409,11 +440,10 @@ class ModelRunner:
             4096,
         )
 
-        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
-            device=device,
+            device=self.device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -425,7 +455,17 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
+            )
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -434,11 +474,11 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         logger.info(
             f"Memory pool end. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
     def init_cublas(self):
@@ -463,12 +503,33 @@ class ModelRunner:
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
-            self.attn_backend = TritonAttnBackend(self)
+            if self.server_args.enable_double_sparsity:
+                self.attn_backend = DoubleSparseAttnBackend(self)
+            else:
+                self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
             )
 
+    def init_double_sparsity_channel_config(self, selected_channel):
+
+        selected_channel = "." + selected_channel + "_proj"
+        self.sorted_channels = []
+        # load channel config
+        with open(self.server_args.ds_channel_config_path, "r") as f:
+            channel_config = json.load(f)
+
+        for i in range(self.model_config.num_hidden_layers):
+            key = "model.layers." + str(i) + ".self_attn" + selected_channel
+            self.sorted_channels.append(
+                torch.tensor(channel_config[key])[
+                    :, : self.server_args.ds_heavy_channel_num
+                ]
+                .contiguous()
+                .cuda()
+            )
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
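
init_double_sparsity_channel_config reads a JSON file keyed by "model.layers.{i}.self_attn.{type}_proj" and keeps only the first ds_heavy_channel_num columns of each entry. Below is a hedged sketch of what such a file could look like; the row semantics (one row per head, columns ordered by importance) are an assumption inferred from the [:, :ds_heavy_channel_num] slice, and all values are invented.

# Sketch of a channel-config file accepted by the loader above (values invented).
import json

channel_config = {
    # one entry per decoder layer; key pattern matches the loader above
    "model.layers.0.self_attn.q_proj": [
        [3, 17, 42, 5],   # assumed: head 0, channel indices ranked by importance
        [8, 1, 29, 64],   # assumed: head 1
    ],
    "model.layers.1.self_attn.q_proj": [
        [11, 2, 90, 7],
        [4, 33, 21, 56],
    ],
}

with open("channels.json", "w") as f:
    json.dump(channel_config, f)
# The loader truncates each row to ds_heavy_channel_num entries, makes the
# tensor contiguous, and moves it to the GPU.
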
@@ -491,11 +552,14 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(forward_batch)
 
+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
     def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
             return self.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
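
The decode path above now derives the per-request position directly from seq_lens before initializing attention metadata outside the CUDA-graph path. A tiny self-contained illustration of that computation, with toy values:

# Toy illustration of the positions computation added above.
import torch

seq_lens = torch.tensor([5, 12, 1])         # current lengths of three requests
positions = (seq_lens - 1).to(torch.int64)  # position of the token being decoded
print(positions)                            # tensor([ 4, 11,  0])
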
sglang/srt/models/baichuan.py CHANGED
@@ -24,7 +24,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -330,7 +329,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -404,7 +403,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
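
The same signature change repeats in baichuan.py above and in every model file that follows: the vllm.config.CacheConfig import and its type annotation are dropped, while the cache_config keyword itself is kept so existing call sites remain valid. A minimal sketch of the pattern with an illustrative class (not an sglang class):

# Illustrative class showing the recurring cache_config change; not from sglang.
from typing import Any, Optional


class ToyForCausalLM:
    def __init__(
        self,
        config: Any,
        cache_config=None,  # was: cache_config: Optional[CacheConfig] = None
        quant_config: Optional[Any] = None,
    ):
        # cache_config is still accepted (callers may pass a config or None),
        # but no import from vllm.config is needed anymore.
        self.config = config
        self.quant_config = quant_config
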
sglang/srt/models/chatglm.py CHANGED
@@ -22,7 +22,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,7 +51,7 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -188,7 +187,7 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -260,7 +259,7 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -308,7 +307,7 @@ class ChatGLMModel(nn.Module):
     def __init__(
         self,
         config,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -359,7 +358,7 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoraConfig] = None,
     ):
sglang/srt/models/commandr.py CHANGED
@@ -45,7 +45,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -320,7 +319,7 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/dbrx.py CHANGED
@@ -20,7 +20,6 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -368,7 +367,7 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ):
         super().__init__()
         self.config = config
sglang/srt/models/deepseek.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -185,7 +184,7 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -262,7 +261,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -331,7 +330,7 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -374,7 +373,7 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
sglang/srt/models/deepseek_v2.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
@@ -188,7 +187,7 @@ class DeepseekV2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -336,7 +335,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -498,7 +497,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -594,7 +593,7 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -640,7 +639,7 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
sglang/srt/models/exaone.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -295,7 +294,7 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/gemma.py CHANGED
@@ -21,7 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -279,7 +279,7 @@ class GemmaForCausalLM(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
sglang/srt/models/gemma2.py CHANGED
@@ -20,7 +20,7 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
@@ -105,7 +105,7 @@ class Gemma2Attention(nn.Module):
         head_dim: int,
         max_position_embeddings: int,
         rope_theta: float,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -190,7 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         layer_idx: int,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -257,7 +257,7 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -336,7 +336,7 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -21,7 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -44,7 +44,7 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +145,7 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -183,7 +183,7 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
@@ -243,7 +243,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
sglang/srt/models/grok.py CHANGED
@@ -23,7 +23,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -289,7 +288,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/internlm2.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -254,7 +253,7 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config