sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -32,6 +32,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
@@ -173,6 +174,7 @@ class ModelRunner:
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
+                "mm_attention_backend": server_args.mm_attention_backend,
             }
         )
@@ -278,9 +280,10 @@ class ModelRunner:
                 server_args.attention_backend = "fa3"
             else:
                 server_args.attention_backend = "triton"
-            logger.info(
-                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+                )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
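Most of the changes in this file wrap informational logs in an `if self.should_log:` guard, as in the hunk above. A minimal sketch of the pattern, assuming `should_log` is initialized so that only one rank (e.g. TP rank 0) emits these messages; the initialization is not shown in this diff:

import logging
from typing import Optional

logger = logging.getLogger(__name__)

class RunnerLoggingSketch:
    """Illustrative only: avoid duplicate startup logs across tensor-parallel ranks."""

    def __init__(self, tp_rank: int):
        # Assumption: only rank 0 should print informational messages.
        self.should_log = tp_rank == 0

    def pick_attention_backend(self, backend: Optional[str]) -> str:
        backend = backend or "triton"
        if self.should_log:
            logger.info(f"Attention backend not set. Use {backend} backend by default.")
        return backend

# With 4 ranks, only rank 0 logs; every rank still gets the same backend value.
backends = [RunnerLoggingSketch(tp_rank=r).pick_attention_backend(None) for r in range(4)]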
@@ -290,9 +293,10 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
+                    if self.should_log:
+                        logger.info(
+                            f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                        )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -311,9 +315,10 @@ class ModelRunner:
             server_args.attention_backend = "triton"

         if server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
+            if self.should_log:
+                logger.info(
+                    "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+                )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
@@ -324,23 +329,26 @@ class ModelRunner:

         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-            logger.info(
-                "Automatically turn off --chunked-prefill-size for multimodal model."
-            )
+            if self.should_log:
+                logger.info(
+                    f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                    f"because this is a multimodal model."
+                )
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size for multimodal model."
+                )
             server_args.chunked_prefill_size = -1

         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-            logger.info("Disable chunked prefix cache when page size > 1.")
+            if self.should_log:
+                logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
-            logger.info("Chunked prefix cache is turned on.")
+            if self.should_log:
+                logger.info("Chunked prefix cache is turned on.")

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -361,6 +369,8 @@ class ModelRunner:
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
+        elif self.device == "npu":
+            backend = "hccl"

         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
@@ -391,11 +401,15 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )

         min_per_gpu_memory = get_available_gpu_memory(
-            self.device,
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
@@ -431,9 +445,10 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-                logger.info(
-                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-                )
+                if self.should_log:
+                    logger.info(
+                        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                    )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -469,10 +484,11 @@ class ModelRunner:
                 self.model.load_kv_cache_scales(
                     self.server_args.quantization_param_path
                 )
-                logger.info(
-                    "Loaded KV cache scaling factors from %s",
-                    self.server_args.quantization_param_path,
-                )
+                if self.should_log:
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
             else:
                 raise RuntimeError(
                     "Using FP8 KV cache and scaling factors provided but "
@@ -547,12 +563,7 @@ class ModelRunner:
             return iter

         def model_load_weights(model, iter):
-            model.load_weights(iter)
-            for _, module in self.model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
+            DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
             return model

         with set_default_torch_dtype(self.model_config.dtype):
@@ -710,7 +721,10 @@ class ModelRunner:

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device,
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
             num_layers = (
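Both `profile_max_num_token` and the earlier `min_per_gpu_memory` call now pass `distributed=get_world_group().world_size > 1` and `cpu_group=get_world_group().cpu_group` to `get_available_gpu_memory`. A rough sketch of what such a distributed memory check can look like; the helper below is illustrative and assumes a min-reduction over a CPU (gloo) process group, not the package's actual implementation:

import torch
import torch.distributed as dist

def min_free_memory_gb(local_free_gb: float, distributed: bool, cpu_group=None) -> float:
    """Sketch: every rank reports its local free memory and takes the minimum,
    so KV-cache sizing follows the most constrained GPU in the group."""
    if not distributed:
        return local_free_gb
    # Use a CPU tensor so the collective can run over a gloo group.
    value = torch.tensor([local_free_gb], dtype=torch.float64)
    dist.all_reduce(value, op=dist.ReduceOp.MIN, group=cpu_group)
    return value.item()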
@@ -1019,7 +1033,8 @@ class ModelRunner:
         )

     def apply_torch_tp(self):
-        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        if self.should_log:
+            logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel

         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
@@ -1078,32 +1093,33 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            return self.cuda_graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
sglang/srt/model_loader/loader.py
CHANGED
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )

-            model.load_weights(self._get_all_weights(model_config, model))
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )

-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
         return model.eval()

+    @staticmethod
+    def load_weights_and_postprocess(model, weights, target_device):
+        model.load_weights(weights)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+

 class LayeredModelLoader(DefaultModelLoader):
     """Model loader that loads weights layer by layer so that one can quantize a
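The loop that ran quantization post-processing after weight loading is factored out into the `DefaultModelLoader.load_weights_and_postprocess` static method, which the `model_load_weights` helper in model_runner.py (earlier hunk) now calls as well. A short usage sketch; the `reload_weights` wrapper is hypothetical, while the static method and its `(model, weights, target_device)` signature come from the diff:

import torch
from sglang.srt.model_loader.loader import DefaultModelLoader

def reload_weights(model, weights_iter, device: str = "cuda"):
    # One call now covers both load_weights() and the per-module
    # quant_method.process_weights_after_loading() pass on the target device.
    DefaultModelLoader.load_weights_and_postprocess(model, weights_iter, torch.device(device))
    return model.eval()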
sglang/srt/models/clip.py
CHANGED
@@ -151,20 +151,20 @@ class CLIPEncoderLayer(nn.Module):
         self.layer_norm1 = norm_layer(config.hidden_size)
         self.layer_norm2 = norm_layer(config.hidden_size)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
         self.self_attn = VisionAttention(
             embed_dim=config.hidden_size,
             num_heads=config.num_attention_heads,
             projection_size=config.hidden_size,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
sglang/srt/models/deepseek_janus_pro.py
CHANGED
@@ -188,7 +188,7 @@ def trunc_normal_tf_(
     best when :math:`a \\leq \\text{mean} \\leq b`.
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is
+    and the result is subsequently scaled and shifted by the mean and std args.
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -532,7 +532,7 @@ class VisionTransformerBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             dropout=attn_drop,
         )
@@ -735,7 +735,7 @@ class VisionTransformer(nn.Module):
             img_size: Input image size.
             patch_size: Patch size.
             in_chans: Number of image input channels.
-            num_classes:
+            num_classes: Number of classes for classification head.
             global_pool: Type of global pooling for final sequence (default: 'token').
             embed_dim: Transformer embedding dimension.
             depth: Depth of transformer.
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
-
+from sglang.srt.utils import BumpAllocator, add_prefix

 logger = logging.getLogger(__name__)
