sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their respective public registries and is provided for informational purposes only.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
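Nearly all of the per-model hunks below apply one mechanical change: the CacheConfig import from vllm.config is removed (or trimmed to the names still used, such as LoRAConfig), and the cache_config constructor parameter keeps its None default but drops the Optional[CacheConfig] annotation, so the model classes no longer depend on vllm's config module. A minimal sketch of the resulting constructor shape, using a hypothetical ToyForCausalLM class rather than any file from this diff:

from typing import Optional

from torch import nn
from transformers import PretrainedConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig


class ToyForCausalLM(nn.Module):  # hypothetical example, not part of sglang
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,  # 0.3.3.post1 annotated this as Optional[CacheConfig]
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config

Keeping the parameter (now untyped) lets existing call sites that still pass cache_config continue to work.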
sglang/srt/models/deepseek.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -185,7 +184,7 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -262,7 +261,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -331,7 +330,7 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -374,7 +373,7 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
sglang/srt/models/deepseek_v2.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
@@ -188,7 +187,7 @@ class DeepseekV2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -336,7 +335,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -498,7 +497,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -594,7 +593,7 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -640,7 +639,7 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
sglang/srt/models/exaone.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -295,7 +294,7 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/gemma.py CHANGED
@@ -21,7 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -279,7 +279,7 @@ class GemmaForCausalLM(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
sglang/srt/models/gemma2.py CHANGED
@@ -20,7 +20,7 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
@@ -105,7 +105,7 @@ class Gemma2Attention(nn.Module):
         head_dim: int,
         max_position_embeddings: int,
         rope_theta: float,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -190,7 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         layer_idx: int,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -257,7 +257,7 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -336,7 +336,7 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -21,7 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -44,7 +44,7 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +145,7 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -183,7 +183,7 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
@@ -243,7 +243,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
sglang/srt/models/grok.py CHANGED
@@ -23,7 +23,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -289,7 +288,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/internlm2.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -254,7 +253,7 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama.py CHANGED
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -295,7 +294,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_classification.py CHANGED
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -32,7 +31,7 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_reward.py CHANGED
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -33,7 +32,7 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -92,7 +91,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__(config, quant_config, cache_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)
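Beyond the cache_config cleanup, the llava.py hunks below also stop building a separate image_offsets list and instead read the offsets directly from each request's image_inputs entry inside the placeholder-filling loop. A rough sketch of the equivalence with stand-in data (ImageInputs here is a simplified placeholder, not sglang's actual class):

from dataclasses import dataclass, field
from typing import List


@dataclass
class ImageInputs:  # simplified stand-in for sglang's per-request image metadata
    image_offsets: List[int] = field(default_factory=list)


image_inputs = [ImageInputs([3, 17]), ImageInputs([5])]
prefix_lens_cpu = [4, 0]

for i, prefix_len in enumerate(prefix_lens_cpu):
    # 0.3.4 enumerates the offsets straight off image_inputs[i] instead of
    # first materializing a parallel image_offsets list for the whole batch.
    for j, image_offset in enumerate(image_inputs[i].image_offsets):
        if image_offset < prefix_len:
            # this image is already covered by the cached prefix
            continue
        print(f"request {i}: fill placeholder for image {j} at offset {image_offset}")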
sglang/srt/models/llava.py CHANGED
@@ -31,7 +31,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -161,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 image_sizes = [
                     image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
                 ]
-                image_offsets = [
-                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
-                ]
 
                 ########## Encode Image ########
 
@@ -359,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 prefix_len = prefix_lens_cpu[i]
 
                 # Multiple images
-                for j, image_offset in enumerate(image_offsets[i]):
+                for j, image_offset in enumerate(image_inputs[i].image_offsets):
                     if image_offset < prefix_len:
                         continue
 
@@ -450,7 +446,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
 
@@ -472,7 +468,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
 
@@ -505,7 +501,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
 
sglang/srt/models/llavavid.py CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -36,7 +35,7 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/minicpm.py CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -278,7 +277,7 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/minicpm3.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -108,7 +107,7 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -252,7 +251,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -409,7 +408,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -501,7 +500,7 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -552,7 +551,7 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
sglang/srt/models/mixtral.py CHANGED
@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -293,7 +292,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/mixtral_quant.py CHANGED
@@ -23,7 +23,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -325,7 +324,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config