sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpm.py
CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -308,12 +307,10 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
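
Note: the recurring pattern in the model diffs in this release is the new logits-processor call. The 0.3.6 code passed a raw weight tensor (lm_head_weight); the 0.4.0 code passes the head module itself plus forward_batch, selecting self.model.embed_tokens when tie_word_embeddings is set. Below is a minimal standalone sketch of that selection pattern using toy torch modules; compute_logits and ToyCausalLM are illustrative stand-ins, not SGLang APIs.

import torch
from torch import nn

def compute_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # Hypothetical stand-in for the logits processor: it receives the head
    # *module* (0.4.0 style) and reads its weight, instead of a raw weight tensor.
    return hidden_states @ lm_head.weight.t()

class ToyCausalLM(nn.Module):
    def __init__(self, vocab_size=128, hidden_size=16, tie_word_embeddings=True):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if not tie_word_embeddings:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirrors the 0.4.0 pattern shown in the minicpm.py diff above:
        # pick the head module, then hand it to the logits processor.
        lm_head = self.embed_tokens if self.tie_word_embeddings else self.lm_head
        return compute_logits(hidden_states, lm_head)

logits = ToyCausalLM()(torch.randn(2, 16))
print(logits.shape)  # torch.Size([2, 128])

Passing the module rather than its weight presumably lets the processor handle quantized or vocab-parallel heads itself, which matches the new lm_head argument appearing throughout these files.
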
sglang/srt/models/minicpm3.py
CHANGED
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):

         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
@@ -585,12 +574,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/mixtral.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MixtralMoE(nn.Module):
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -310,7 +309,7 @@ class MixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -340,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                    continue

                param = params_dict[name]
@@ -354,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
                        continue
                    name = name.replace(weight_name, param_name)

+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
@@ -366,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                        continue
                    # Skip loading kv_scale from ckpts towards new design.
                    if name.endswith(".kv_scale") and name not in params_dict:
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MixtralMLP(nn.Module):
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -343,7 +342,7 @@ class QuantMixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
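
Note: the mixtral.py load_weights hunks above widen the GPTQ extra-bias skip from names ending in ".bias" to names ending in ".bias" or "_bias" that have no matching parameter. A small self-contained sketch of that predicate; should_skip_extra_bias is a hypothetical helper name used here for illustration, not an SGLang function.

def should_skip_extra_bias(name: str, params_dict: dict) -> bool:
    # Skip loading extra bias for GPTQ models: checkpoints may carry
    # ".bias"/"_bias" tensors that the in-memory module has no parameter for.
    return (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict

params_dict = {"model.layers.0.self_attn.qkv_proj.weight": None}
print(should_skip_extra_bias("model.layers.0.self_attn.qkv_proj.bias", params_dict))    # True
print(should_skip_extra_bias("model.layers.0.self_attn.qkv_proj.weight", params_dict))  # False
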
sglang/srt/models/mllama.py
CHANGED
@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
     _prepare_aspect_ratio_attention_mask,
 )
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.layernorm import RMSNorm
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP


@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.model = MllamaTextModel(config,
+        self.model = MllamaTextModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.text_config.vocab_size
@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self.vision_model = MllamaVisionModel(config.vision_config)
         self.language_model = MllamaForCausalLM(
             config.text_config,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.multi_modal_projector = nn.Linear(
@@ -966,7 +962,7 @@ class MllamaForConditionalGeneration(nn.Module):
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head
+            input_ids, hidden_states, self.language_model.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmo.py
CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import OlmoConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers


@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
     def __init__(
         self,
         config: OlmoConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -306,7 +305,7 @@ class OlmoForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +325,6 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
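
Note: every model file in this diff swaps the default_weight_loader import from vllm.model_executor.model_loader.weight_utils to the new sglang.srt.model_loader.weight_utils module (added in this release, +640 lines). The sketch below shows the fallback-loader pattern these load_weights methods typically rely on; ToyModel is illustrative, and the local default_weight_loader is only a minimal re-statement (shape check plus copy), not the actual implementation.

import torch
from torch import nn

# New import path in sglang 0.4.0 (was vllm.model_executor.model_loader.weight_utils):
# from sglang.srt.model_loader.weight_utils import default_weight_loader

def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Minimal stand-in: verify the shapes match, then copy the checkpoint tensor in.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)

    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name]
            # Fall back to the default loader when the parameter does not
            # define a custom weight_loader attribute.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

m = ToyModel()
m.load_weights([("proj.weight", torch.zeros(4, 4))])
print(m.proj.weight.sum().item())  # 0.0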