sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
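Several modules moved between these two versions: the attention backends now live under `sglang/srt/layers/attention/` (replacing the deleted `attention_backend.py`), the Triton kernels moved from `layers/triton_attention/` to `layers/attention/triton_ops/`, and `managers/policy_scheduler.py` became `managers/schedule_policy.py`. The sketch below shows how downstream code could absorb one of these renames; it is illustrative only, and the try/except fallback is an assumption rather than anything shipped in either wheel.

```python
# Illustrative sketch only: import paths implied by the renames in the file list above.
# The fallback pattern is an assumption for code that must run against both versions.
try:
    # sglang 0.3.3: Triton attention kernels live under layers/attention/triton_ops/
    from sglang.srt.layers.attention.triton_ops import extend_attention
except ImportError:
    # sglang 0.3.1.post3: the same module lived under layers/triton_attention/
    from sglang.srt.layers.triton_attention import extend_attention
```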
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional,
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
@@ -38,20 +38,23 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
-from sglang.srt.
+from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    enable_show_time_cost,
     get_available_gpu_memory,
     is_generation_model,
     is_multimodal_model,
@@ -87,6 +90,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )
 
+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -94,6 +98,19 @@ class ModelRunner:
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
+        # Global vars
+        if server_args.show_time_cost:
+            enable_show_time_cost()
+        if server_args.disable_disk_cache:
+            disable_cache()
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -104,14 +121,6 @@ class ModelRunner:
             }
         )
 
-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
@@ -135,8 +144,8 @@ class ModelRunner:
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
 
-        if self.server_args.
-            nccl_init_method = f"tcp://{self.server_args.
+        if self.server_args.dist_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
@@ -222,6 +231,7 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
@@ -399,9 +409,11 @@ class ModelRunner:
             4096,
         )
 
+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -413,6 +425,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -421,6 +434,7 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         logger.info(
             f"Memory pool end. "
@@ -445,6 +459,10 @@ class ModelRunner:
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
+            assert not self.has_cross_attention, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
             self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
@@ -467,73 +485,59 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
-
-
-
-
-
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
-            return self.cuda_graph_runner.replay(batch)
-
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+    def forward_decode(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
+            forward_batch.batch_size
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)
 
         return self.model.forward(
-
+            forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-
-    def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
-
+    def forward_extend(self, forward_batch: ForwardBatch):
         if self.is_generation:
             return self.model.forward(
-
+                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-
-
-
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
                 get_embedding=True,
             )
 
-
-
-
-
-
-
-
-                input_metadata.pixel_values,
-                input_metadata.image_sizes,
-                input_metadata.image_offsets,
-            )
+    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(forward_batch)
+        elif forward_batch.forward_mode.is_extend():
+            return self.forward_extend(forward_batch)
+        else:
+            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
 
-    def
-
+    def sample(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        sampling_info = forward_batch.sampling_info
+        sampling_info.update_regex_vocab_mask()
+        sampling_info.update_penalties()
+        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
-
-
-
-            return self.forward_decode(batch)
-        elif batch.forward_mode.is_extend():
-            return self.forward_extend(batch)
-        else:
-            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, sampling_info)
+        return next_token_ids
 
-    def
-        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
-    ):
+    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Apply logit_bias
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)
 
         # min-token, presence, frequency
         if sampling_info.linear_penalties is not None:
-            logits
+            logits.add_(sampling_info.linear_penalties)
 
         # repetition
         if sampling_info.scaling_penalties is not None:
@@ -549,20 +553,6 @@ class ModelRunner:
 
         return logits
 
-    def sample(
-        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
-    ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-        batch.sampling_info.update_regex_vocab_mask(batch)
-        batch.sampling_info.update_penalties()
-        logits = self._apply_logits_bias(
-            logits_output.next_token_logits, batch.sampling_info
-        )
-
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, batch.sampling_info)
-        return next_token_ids
-
 
 @lru_cache()
 def import_model_classes():
@@ -571,17 +561,25 @@ def import_model_classes():
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
-
+            try:
+                module = importlib.import_module(name)
+            except Exception as e:
+                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+                continue
             if hasattr(module, "EntryClass"):
                 entry = module.EntryClass
                 if isinstance(
                     entry, list
                 ):  # To support multiple model classes in one module
                     for tmp in entry:
-                        assert
+                        assert (
+                            tmp.__name__ not in model_arch_name_to_cls
+                        ), f"Duplicated model implementation for {tmp.__name__}"
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
-                    assert
+                    assert (
+                        entry.__name__ not in model_arch_name_to_cls
+                    ), f"Duplicated model implementation for {entry.__name__}"
                     model_arch_name_to_cls[entry.__name__] = entry
 
     return model_arch_name_to_cls
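Net effect of the model_runner.py hunks above: ModelRunner now consumes a single ForwardBatch instead of a ScheduleBatch plus derived InputMetadata, forward() dispatches on forward_batch.forward_mode, and sample() reads its state from forward_batch.sampling_info. A minimal sketch of the resulting call sequence, assuming the scheduler has already built the ForwardBatch, follows; `run_step` and its variable names are placeholders, and only the ModelRunner methods come from the diff.

```python
# Sketch of the 0.3.3 calling convention, based on the hunks above.
# `run_step` is a hypothetical helper; only the ModelRunner methods are from the wheel.
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner


def run_step(model_runner: ModelRunner, forward_batch: ForwardBatch):
    # forward() dispatches on forward_batch.forward_mode (decode vs. extend).
    logits_output: LogitsProcessorOutput = model_runner.forward(forward_batch)
    # sample() applies logit bias and penalties, then samples the next token ids.
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids
```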
sglang/srt/models/baichuan.py
CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
 
         # Fully Connected
@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
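The baichuan.py diff above is representative of the per-model changes in this release: each forward() gains a forward_batch: ForwardBatch parameter that is threaded down to RadixAttention and the LogitsProcessor in place of the old metadata argument, and the chatglm.py, commandr.py, and dbrx.py diffs below apply the same mechanical rename. For out-of-tree model code, a small signature check like the one below can tell which convention an installed build expects; the helper is hypothetical, and only the forward_batch parameter name comes from the diffs.

```python
# Hypothetical helper (not part of sglang): detect which forward() convention
# an installed model class uses by inspecting its signature.
import inspect


def expects_forward_batch(model_cls) -> bool:
    # In 0.3.3 the parameter is named `forward_batch` (see the hunks above and below).
    return "forward_batch" in inspect.signature(model_cls.forward).parameters


# Example usage (assumes sglang is installed):
#   from sglang.srt.models.baichuan import BaiChuanBaseForCausalLM
#   print(expects_forward_batch(BaiChuanBaseForCausalLM))
```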
sglang/srt/models/chatglm.py
CHANGED
@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 LoraConfig = None
 
@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
             q,
             k,
             v,
-
+            forward_batch,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # hidden_states: [num_tokens, h]
         # Layer norm at the beginning of the transformer layer.
@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
         attention_output = self.self_attention(
             hidden_states=layernorm_output,
             position_ids=position_ids,
-
+            forward_batch=forward_batch,
         )
 
         # Residual connection.
@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         for i in range(self.num_layers):
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
-
+                forward_batch=forward_batch,
             )
         # Final layer norm.
         if self.post_layer_norm:
@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)
 
@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
         hidden_states = self.encoder(
             hidden_states=inputs_embeds,
             position_ids=position_ids,
-
+            forward_batch=forward_batch,
         )
         return hidden_states
 
@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions,
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/commandr.py
CHANGED
@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
 
@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
         hidden_states_attention = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states_mlp = self.mlp(hidden_states)
         # Add everything together
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
             positions,
-
+            forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight,
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/dbrx.py
CHANGED
@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
 
@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         hidden_states, _ = self.out_proj(attn_output)
         return hidden_states
 
@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states = residual + x
         residual = hidden_states
@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states, residual = self.norm_attn_norm(
             position_ids=position_ids,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states = self.ffn(hidden_states)
         hidden_states = hidden_states + residual
@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
             hidden_states = input_embeds
         for i in range(len(self.blocks)):
             block = self.blocks[i]
-            hidden_states = block(position_ids, hidden_states,
+            hidden_states = block(position_ids, hidden_states, forward_batch)
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
 
@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions,
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|