sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py
@@ -16,7 +16,7 @@
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple, Union
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

  import torch
  from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
  quant_config=quant_config,
  prefix=add_prefix("lm_head", prefix),
  )
-
  else:
  # ranks other than the last rank will have a placeholder layer
  self.lm_head = PPMissingLayer()
@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):

  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False

  def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
  return self.model.get_input_embedding(input_ids)
@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
  input_embeds,
  pp_proxy_tensors=pp_proxy_tensors,
  )
+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states

  if self.pp_group.is_last_rank:
  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids,
+ hidden_states,
+ self.lm_head,
+ forward_batch,
+ aux_hidden_states,
  )
  else:
  return self.pooler(hidden_states, forward_batch)
@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
  def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  self.model.load_kv_cache_scales(quantization_param_path)

+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+ if not self.pp_group.is_last_rank:
+ return
+
+ self.capture_aux_hidden_states = True
+ if layer_ids is None:
+ num_layers = self.config.num_hidden_layers
+ self.model.layers_to_capture = [
+ 2,
+ num_layers // 2,
+ num_layers - 3,
+ ] # Specific layers for EAGLE3 support
+ else:
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

  EntryClass = Qwen2ForCausalLM
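Note on the qwen2.py changes above: the model now exposes set_eagle3_layers_to_capture() and, when capture_aux_hidden_states is set, forward() unpacks (hidden_states, aux_hidden_states) and passes the auxiliary states on to the logits processor, which is the plumbing EAGLE3 speculative decoding needs. The sketch below is a hypothetical illustration of how these hooks fit together; apart from the two attribute/method names taken from the diff, every name here is a placeholder, not sglang API.

    # Hypothetical sketch (not sglang API): driving the EAGLE3 capture hooks.
    # TinyTargetModel is a stand-in for Qwen2ForCausalLM.
    from typing import List, Optional

    import torch


    class TinyTargetModel:
        def __init__(self, num_hidden_layers: int = 8, hidden: int = 16):
            self.num_hidden_layers = num_hidden_layers
            self.layers_to_capture: List[int] = []
            self.capture_aux_hidden_states = False

        def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
            # Mirrors the diff: default to layers {2, L // 2, L - 3}, else layer_ids + 1.
            self.capture_aux_hidden_states = True
            if layer_ids is None:
                n = self.num_hidden_layers
                self.layers_to_capture = [2, n // 2, n - 3]
            else:
                self.layers_to_capture = [i + 1 for i in layer_ids]

        def forward(self, x: torch.Tensor):
            # Pretend each "layer" just adds its index; keep every layer output around.
            per_layer = [x + i for i in range(self.num_hidden_layers + 1)]
            hidden_states = per_layer[-1]
            if self.capture_aux_hidden_states:
                # Concatenating the captured states along the hidden dim is just one
                # plausible way a draft head might consume them.
                aux = torch.cat([per_layer[i] for i in self.layers_to_capture], dim=-1)
                return hidden_states, aux
            return hidden_states


    model = TinyTargetModel()
    model.set_eagle3_layers_to_capture()  # enable capture with the default layer set
    hidden, aux = model.forward(torch.zeros(2, 4, 16))
    print(hidden.shape, aux.shape)  # aux stacks 3 captured layers -> (..., 48)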
sglang/srt/models/qwen2_5_vl.py
@@ -31,7 +31,6 @@ import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange
  from transformers.activations import ACT2FN
- from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
  from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
  Qwen2_5_VLConfig,
  Qwen2_5_VLVisionConfig,
@@ -43,7 +42,12 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (

  from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.attention.vision import VisionAttention
- from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)


  class Qwen2_5_VLMLP(nn.Module):
-
  def __init__(
  self,
  in_features: int,
@@ -73,19 +76,12 @@
  prefix: str = "",
  ):
  super().__init__()
- self.gate_proj = ColumnParallelLinear(
- in_features,
- hidden_features,
- bias=bias,
- quant_config=quant_config,
- prefix=add_prefix("gate_proj", prefix),
- )
- self.up_proj = ColumnParallelLinear(
- in_features,
- hidden_features,
+ self.gate_up_proj = MergedColumnParallelLinear(
+ input_size=in_features,
+ output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
  bias=bias,
  quant_config=quant_config,
- prefix=add_prefix("up_proj", prefix),
+ prefix=add_prefix("gate_up_proj", prefix),
  )
  self.down_proj = RowParallelLinear(
  hidden_features,
@@ -97,12 +93,11 @@
  self.act = ACT2FN[hidden_act]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- x_parallel_gate, _ = self.gate_proj(x)
- x_parallel_gate = self.act(x_parallel_gate)
- x_parallel_up, _ = self.up_proj(x)
- x_parallel = x_parallel_gate * x_parallel_up
- x, _ = self.down_proj(x_parallel)
- return x
+ gate_up, _ = self.gate_up_proj(x)
+ gate, up = gate_up.chunk(2, dim=-1)
+ x = self.act(gate) * up
+ x_down, _ = self.down_proj(x)
+ return x_down


  class Qwen2_5_VisionBlock(nn.Module):
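The Qwen2_5_VLMLP rewrite above folds the separate gate_proj/up_proj column-parallel projections into one MergedColumnParallelLinear whose output is split in half at runtime (the usual SwiGLU-style fusion: one matmul instead of two). A minimal equivalence check, using plain nn.Linear as a stand-in for the parallel layers and SiLU as an assumed activation:

    # Sketch only: plain nn.Linear stands in for MergedColumnParallelLinear.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden, inter = 16, 32
    gate = nn.Linear(hidden, inter, bias=False)
    up = nn.Linear(hidden, inter, bias=False)

    # Merged projection: concatenate the two weight matrices along the output dim.
    gate_up = nn.Linear(hidden, 2 * inter, bias=False)
    with torch.no_grad():
        gate_up.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))

    x = torch.randn(4, hidden)
    ref = F.silu(gate(x)) * up(x)           # two matmuls (old forward)
    g, u = gate_up(x).chunk(2, dim=-1)      # one matmul, then split (new forward)
    fused = F.silu(g) * u
    assert torch.allclose(ref, fused, atol=1e-6)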
@@ -122,8 +117,8 @@ class Qwen2_5_VisionBlock(nn.Module):
  super().__init__()
  if norm_layer is None:
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
- self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
- self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+ self.norm1 = RMSNorm(dim, eps=1e-6)
+ self.norm2 = RMSNorm(dim, eps=1e-6)

  if attn_implementation is None:
  softmax_in_single_precision = False
@@ -174,18 +169,29 @@
  cu_seqlens: torch.Tensor,
  position_embeddings: torch.Tensor,
  ) -> torch.Tensor:
- hidden_states = self.norm1(x)
- hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+ S, B, H = x.shape
+ # norm1: flatten to 2D -> [S*B, H], then reshape back
+ x2d = x.reshape(-1, H)
+ hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+ # Attention expects [B, S, H]
+ hidden_states = rearrange(hidden_states, "s b h -> b s h")
  attn = self.attn(
  hidden_states,
  cu_seqlens=cu_seqlens,
  position_embeddings=position_embeddings,
  )
- attn = rearrange(attn, "b s ... -> s b ...")
- x = x + attn
- norm2 = self.norm2(x)
- mlp = self.mlp(norm2)
- x = x + mlp
+ attn = rearrange(attn, "b s h -> s b h")
+
+ # norm2 with fused residual-add: also 2D
+ attn2d = attn.reshape(-1, H)
+ x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+ x_norm = x_norm_2d.reshape(S, B, H)
+ x_after_add = x_after_add_2d.reshape(S, B, H)
+
+ # MLP and final residual
+ mlp_out = self.mlp(x_norm)
+ x = x_after_add + mlp_out
  return x

@@ -201,7 +207,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
  ) -> None:
  super().__init__()
  self.hidden_size = context_dim * (spatial_merge_size**2)
- self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+ self.ln_q = RMSNorm(context_dim, eps=1e-6)
  self.mlp = nn.ModuleList(
  [
  ColumnParallelLinear(
@@ -223,11 +229,13 @@
  )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.ln_q(x)
- x = x.view(-1, self.hidden_size)
-
+ # x expected shape: [S, B, context_dim]
+ S, B, D = x.shape
+ x2d = x.reshape(-1, D)
+ x2d = self.ln_q(x2d) # RMSNorm expects 2D
+ x2d = x2d.view(-1, self.hidden_size) # group into spatial_merge_unit
  mlp_fc1, mlp_act, mlp_fc2 = self.mlp
- x_parallel, _ = mlp_fc1(x)
+ x_parallel, _ = mlp_fc1(x2d)
  x_parallel = mlp_act(x_parallel)
  out, _ = mlp_fc2(x_parallel)
  return out
@@ -340,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):

  @property
  def device(self) -> torch.device:
- return self.blocks[0].mlp.gate_proj.weight.device
+ return self.patch_embed.proj.weight.device

  def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
  pos_ids = []
@@ -394,6 +402,12 @@
  )
  cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

+ # Move window_index to the same device as x before using it to index x
+ window_index = window_index.to(device=x.device)
+
+ # Ensure rotary_pos_emb is on the same device/dtype as x
+ rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
+
  seq_len, _ = x.size()

  x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +420,19 @@
  rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
  emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
  position_embeddings = (emb.cos(), emb.sin())
+ # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+ position_embeddings = (
+ position_embeddings[0].to(x.device, x.dtype),
+ position_embeddings[1].to(x.device, x.dtype),
+ )

- # compute cu_seqlens
+ # compute cu_seqlens - move cu_seqlens to GPU and make it int32
  cu_seqlens = torch.cat(
  [
- torch.tensor([0], device=grid_thw.device),
- (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+ torch.tensor([0], device=x.device, dtype=torch.int32),
+ (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+ .cumsum(dim=0)
+ .to(device=x.device, dtype=torch.int32),
  ]
  )
  cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
@@ -442,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
  class Qwen2_5_VLForConditionalGeneration(nn.Module):
  # BitandBytes specific attributes
  default_bitsandbytes_target_modules = [
- ".gate_proj.",
+ ".gate_up_proj.",
  ".down_proj.",
- ".up_proj.",
  ".q_proj.",
  ".k_proj.",
  ".v_proj.",
@@ -526,6 +546,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  def get_input_embeddings(self):
  return self.model.embed_tokens

+ @torch.no_grad()
  def forward(
  self,
  input_ids: torch.Tensor,
@@ -590,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
- if "visual" in name:
+ if (
+ "visual" in name
+ and "up_proj" not in name
+ and "gate_proj" not in name
+ ):
  continue
  name = name.replace(weight_name, param_name)

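The load_weights tweak at the end of the qwen2_5_vl.py diff lets visual gate_proj/up_proj checkpoint weights pass through the stacked-parameter remapping so they land in the merged gate_up_proj parameter. The exact stacked_params_mapping entries are not shown in this hunk; the toy loop below uses assumed entries to illustrate the (param_name, weight_name, shard_id) convention:

    # Illustrative only: assumed mapping entries and a toy weight-merge loop.
    import torch

    hidden, inter = 16, 32
    # Checkpoint-style weights keyed by their original (unmerged) names.
    checkpoint = {
        "visual.blocks.0.mlp.gate_proj.weight": torch.randn(inter, hidden),
        "visual.blocks.0.mlp.up_proj.weight": torch.randn(inter, hidden),
    }
    # (param_name, weight_name, shard_id) triples, assumed for this sketch.
    stacked_params_mapping = [
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Destination merged parameter: [gate | up] stacked along dim 0.
    gate_up = torch.empty(2 * inter, hidden)

    for name, loaded in checkpoint.items():
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # shard_id selects which half of the merged weight to fill.
            gate_up[shard_id * inter : (shard_id + 1) * inter].copy_(loaded)
            break

    print(gate_up.shape)  # torch.Size([64, 16])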
sglang/srt/models/qwen2_moe.py
@@ -17,7 +17,7 @@
  """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

  import logging
- from typing import Any, Dict, Iterable, Optional, Tuple, Union
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
  use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
  )
  self.logits_processor = LogitsProcessor(config)
+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False

  @torch.no_grad()
  def forward(
@@ -553,9 +555,12 @@
  input_embeds,
  pp_proxy_tensors=pp_proxy_tensors,
  )
+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states
  if self.pp_group.is_last_rank:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
  )
  else:
  return hidden_states
@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module):
  num_groups=None,
  )

+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+ if not self.pp_group.is_last_rank:
+ return
+
+ self.capture_aux_hidden_states = True
+ if layer_ids is None:
+ num_layers = self.config.num_hidden_layers
+ self.model.layers_to_capture = [
+ 2,
+ num_layers // 2,
+ num_layers - 3,
+ ] # Specific layers for EAGLE3 support
+ else:
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

  EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/transformers.py
@@ -213,7 +213,7 @@ class TransformersForCausalLM(nn.Module):
  """
  tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}

- if not tp_plan and self.tp_size > 1:
+ if not tp_plan and tp_size > 1:
  raise ValueError(
  f"{type(self.model)} does not support tensor parallel yet!"
  )
sglang/srt/multimodal/processors/base_processor.py
@@ -13,7 +13,9 @@ from PIL import Image
  from transformers import BaseImageProcessorFast

  from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
- from sglang.srt.utils import load_audio, load_image, load_video, logger
+ from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+ _is_npu = is_npu()


  @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
  and isinstance(processor.image_processor, BaseImageProcessorFast)
  and not self.server_args.disable_fast_image_processor
  ):
- kwargs["device"] = "cuda"
+ kwargs["device"] = "cuda" if not _is_npu else "npu"
  result = processor.__call__(
  text=[input_text],
  padding=True,
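The base_processor.py change probes for an Ascend NPU once at import time and routes the fast image processor to "npu" instead of "cuda". A tiny sketch of the same pattern; the hasattr probe below is an assumption for illustration, sglang itself uses its is_npu() helper from sglang.srt.utils:

    # Sketch only: import-time accelerator detection gating the device string.
    import torch

    _is_npu = hasattr(torch, "npu") and torch.npu.is_available()  # assumed probe

    kwargs = {}
    kwargs["device"] = "cuda" if not _is_npu else "npu"
    print(kwargs)  # {'device': 'cuda'} on a CUDA build, {'device': 'npu'} on Ascend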