sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -173,6 +174,7 @@ class ModelRunner:
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
+                "mm_attention_backend": server_args.mm_attention_backend,
             }
         )

@@ -278,9 +280,10 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            logger.info(
-                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+                )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
@@ -290,9 +293,10 @@
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
+                    if self.should_log:
+                        logger.info(
+                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                        )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -311,9 +315,10 @@
             server_args.attention_backend = "triton"

         if server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
+            if self.should_log:
+                logger.info(
+                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+                )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -324,23 +329,26 @@

         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-            logger.info(
-                "Automatically turn off --chunked-prefill-size for multimodal model."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                    f"because this is a multimodal model."
+                )
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size for multimodal model."
+                )
             server_args.chunked_prefill_size = -1

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
+            if self.should_log:
+                logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
-            logger.info("Chunked prefix cache is turned on.")
+            if self.should_log:
+                logger.info("Chunked prefix cache is turned on.")

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -361,6 +369,8 @@
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
+        elif self.device == "npu":
+            backend = "hccl"

         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
@@ -391,11 +401,15 @@
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )

         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -431,9 +445,10 @@
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                logger.info(
-                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                )
+                if self.should_log:
+                    logger.info(
+                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                    )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -469,10 +484,11 @@
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                logger.info(
-                    "Loaded KV cache scaling factors from %s",
-                    self.server_args.quantization_param_path,
-                )
+                if self.should_log:
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -547,12 +563,7 @@
             return iter

         def model_load_weights(model, iter):
-            model.load_weights(iter)
-            for _, module in self.model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
             return model

         with set_default_torch_dtype(self.model_config.dtype):
@@ -710,7 +721,10 @@

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
@@ -1019,7 +1033,8 @@
         )

     def apply_torch_tp(self):
-        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        if self.should_log:
+            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel

         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
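A recurring change in the hunks above is that per-startup informational logs are now gated behind self.should_log, so multi-GPU launches do not repeat the same message on every rank. A minimal sketch of the pattern, assuming should_log is a boolean set once when the ModelRunner is constructed (the tp_rank == 0 condition is an assumption for illustration, not taken from this diff):

    # Hypothetical illustration of the logging gate used throughout this file.
    # Assumption: only one rank (e.g. tensor-parallel rank 0) should emit startup logs.
    self.should_log = self.tp_rank == 0

    if self.should_log:
        logger.info(
            f"Attention backend not set. Use {server_args.attention_backend} backend by default."
        )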
@@ -1078,32 +1093,33 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            return self.cuda_graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
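ModelRunner.forward now returns a tuple rather than a single value: the model output plus a flag indicating whether the batch was served by a captured CUDA graph. A minimal sketch of an adapted call site; the model_runner and forward_batch names are illustrative, not taken from the diff:

    # forward() now also reports whether the CUDA graph replay path was taken.
    logits_output, can_run_cuda_graph = model_runner.forward(forward_batch)
    # The boolean can feed scheduler-side bookkeeping, e.g. metrics on CUDA graph hit rate.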
sglang/srt/model_loader/loader.py CHANGED
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config,
             )

-        model.load_weights(self._get_all_weights(model_config, model))
+        self.load_weights_and_postprocess(
+            model, self._get_all_weights(model_config, model), target_device
+        )

-        for _, module in model.named_modules():
-            quant_method = getattr(module, "quant_method", None)
-            if quant_method is not None:
-                # When quant methods need to process weights after loading
-                # (for repacking, quantizing, etc), they expect parameters
-                # to be on the global target device. This scope is for the
-                # case where cpu offloading is used, where we will move the
-                # parameters onto device for processing and back off after.
-                with device_loading_context(module, target_device):
-                    quant_method.process_weights_after_loading(module)
         return model.eval()

+    @staticmethod
+    def load_weights_and_postprocess(model, weights, target_device):
+        model.load_weights(weights)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+

 class LayeredModelLoader(DefaultModelLoader):
     """Model loader that loads weights layer by layer so that one can quantize a
sglang/srt/models/clip.py CHANGED
@@ -151,20 +151,20 @@ class CLIPEncoderLayer(nn.Module):
         self.layer_norm1 = norm_layer(config.hidden_size)
         self.layer_norm2 = norm_layer(config.hidden_size)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
         self.self_attn = VisionAttention(
             embed_dim=config.hidden_size,
             num_heads=config.num_attention_heads,
             projection_size=config.hidden_size,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
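VisionAttention now selects its attention implementation through a qkv_backend string ("sdpa", "triton_attn", ...) instead of the old use_context_forward flag (VisionAttention lives in sglang/srt/layers/attention/vision.py per the file list above). A minimal construction sketch using only the keyword arguments visible in the hunk above; the config values are placeholders:

    from sglang.srt.layers.attention.vision import VisionAttention

    # Equivalent of the old use_context_forward=False path, now expressed as a backend name.
    self_attn = VisionAttention(
        embed_dim=hidden_size,
        num_heads=num_attention_heads,
        projection_size=hidden_size,
        use_qkv_parallel=True,
        qkv_backend="sdpa",
        softmax_in_single_precision=False,
        flatten_batch=True,
        quant_config=None,
    )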
@@ -188,7 +188,7 @@ def trunc_normal_tf_(
     best when :math:`a \\leq \text{mean} \\leq b`.
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is subsquently scaled and shifted by the mean and std args.
+    and the result is subsequently scaled and shifted by the mean and std args.
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -532,7 +532,7 @@ class VisionTransformerBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             dropout=attn_drop,
         )
@@ -735,7 +735,7 @@ class VisionTransformer(nn.Module):
         img_size: Input image size.
         patch_size: Patch size.
         in_chans: Number of image input channels.
-        num_classes: Mumber of classes for classification head.
+        num_classes: Number of classes for classification head.
         global_pool: Type of global pooling for final sequence (default: 'token').
         embed_dim: Transformer embedding dimension.
         depth: Depth of transformer.
sglang/srt/models/deepseek_nextn.py CHANGED
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
-
+from sglang.srt.utils import BumpAllocator, add_prefix

 logger = logging.getLogger(__name__)
