sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +41 -5
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
- sglang/srt/layers/parameter.py +2 -1
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/fp8.py +6 -3
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +25 -2
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +277 -178
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +206 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +37 -15
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/sampling_batch_info.py +139 -4
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +57 -14
- sglang/srt/utils.py +103 -65
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/custom_logit_processor.py (new file)

@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+    """Deserialize a json string to a Callable object.
+    This function is cached to avoid redundant deserialization.
+    """
+    data = json.loads(json_str)
+    return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+    """Abstract base class for callable functions."""
+
+    @abstractmethod
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        """Define the callable behavior."""
+        raise NotImplementedError
+
+    def to_str(self) -> str:
+        """Serialize the callable function to a JSON-compatible string."""
+        return json.dumps({"callable": dill.dumps(self).hex()})
+
+    @classmethod
+    def from_str(cls, json_str: str):
+        """Deserialize a callable function from a JSON string."""
+        return _cache_from_str(json_str)
sglang/srt/sampling/sampling_batch_info.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 import logging
 import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import torch

@@ -14,6 +14,7 @@ if is_cuda:
     from sgl_kernel import sampling_scaling_penalties

 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

 logger = logging.getLogger(__name__)

@@ -36,6 +37,9 @@ class SamplingBatchInfo:
     # Dispatch in CUDA graph
     need_min_p_sampling: bool

+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool
+
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
@@ -52,6 +56,14 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"

+    # Custom Parameters
+    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+    # Custom Logit Processor
+    custom_logit_processor: Optional[
+        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+    ] = None
+
     @classmethod
     def from_schedule_batch(
         cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +88,39 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        # Check if any request has custom logit processor
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
+
+        if has_custom_logit_processor:
+            # Merge the same type of custom logit processors together
+            processor_dict = {}
+            for i, r in enumerate(reqs):
+                if r.custom_logit_processor is None:
+                    continue
+                processor_str = r.custom_logit_processor
+                if processor_str not in processor_dict:
+                    processor_dict[processor_str] = []
+                processor_dict[processor_str].append(i)
+
+            merged_custom_logit_processor = {
+                hash(processor_str): (
+                    # The deserialized custom logit processor object
+                    CustomLogitProcessor.from_str(processor_str),
+                    # The mask tensor for the requests that use this custom logit processor
+                    torch.zeros(len(reqs), dtype=torch.bool)
+                    .scatter_(0, torch.tensor(true_indices), True)
+                    .to(device, non_blocking=True),
+                )
+                for processor_str, true_indices in processor_dict.items()
+            }
+            custom_params = [r.sampling_params.custom_params for r in reqs]
+        else:
+            merged_custom_logit_processor = None
+            custom_params = None
+
         ret = cls(
             temperatures=temperatures,
             top_ps=top_ps,
@@ -83,8 +128,11 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            has_custom_logit_processor=has_custom_logit_processor,
             vocab_size=vocab_size,
             device=device,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -184,6 +232,8 @@ class SamplingBatchInfo:

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.has_custom_logit_processor:
+            self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

         for item in [
             "temperatures",
@@ -196,6 +246,27 @@ class SamplingBatchInfo:
             if value is not None:  # logit_bias can be None
                 setattr(self, item, value[new_indices])

+    def _filter_batch_custom_logit_processor(
+        self, unfinished_indices: List[int], new_indices: torch.Tensor
+    ):
+        """Filter the custom logit processor and custom params"""
+
+        self.custom_logit_processor = {
+            k: (p, mask[new_indices])
+            for k, (p, mask) in self.custom_logit_processor.items()
+            if any(
+                mask[new_indices]
+            )  # ignore the custom logit processor whose mask is all False
+        }
+        self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
+            self.custom_logit_processor = None
+            self.custom_params = None
+            self.has_custom_logit_processor = False
+
     @staticmethod
     def merge_bias_tensor(
         lhs: torch.Tensor,
@@ -221,9 +292,76 @@ class SamplingBatchInfo:

         return None

+    @staticmethod
+    def merge_custom_logit_processor(
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        bs1: int,
+        bs2: int,
+        device: str,
+    ):
+        if lhs is None and rhs is None:
+            return None
+        lhs, rhs = lhs or {}, rhs or {}
+
+        keys = set(lhs.keys()).union(set(rhs.keys()))
+        merged_dict = {}
+
+        for k in keys:
+            # Get the logit processor object
+            processor = lhs[k][0] if k in lhs else rhs[k][0]
+            # Get and merge the mask tensors from the two dicts
+            left_mask = (
+                lhs[k][1]
+                if k in lhs
+                else torch.zeros(bs1, dtype=torch.bool, device=device)
+            )
+            right_mask = (
+                rhs[k][1]
+                if k in rhs
+                else torch.zeros(bs2, dtype=torch.bool, device=device)
+            )
+            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
+        return merged_dict
+
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

+        # Merge the logit bias tensor
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
+        )
+        # Merge the custom logit processors and custom params lists
+        if self.has_custom_logit_processor or other.has_custom_logit_processor:
+            # Merge the custom logit processors
+            self.custom_logit_processor = (
+                SamplingBatchInfo.merge_custom_logit_processor(
+                    self.custom_logit_processor,
+                    other.custom_logit_processor,
+                    len(self),
+                    len(other),
+                    self.device,
+                )
+            )
+            # Merge the custom params lists
+            self.custom_params = self.custom_params or [None] * len(self)
+            other.custom_params = other.custom_params or [None] * len(other)
+            self.custom_params.extend(other.custom_params)
+
+            # Set the flag to True if any of the two has custom logit processor
+            self.has_custom_logit_processor = True
+
+        # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
         for item in [
             "temperatures",
             "top_ps",
@@ -235,9 +373,6 @@ class SamplingBatchInfo:
             setattr(self, item, torch.concat([self_val, other_val]))

         self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other), self.device
-        )
         self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

     def apply_logits_bias(self, logits: torch.Tensor):
sglang/srt/sampling/sampling_params.py

@@ -13,7 +13,7 @@
 # ==============================================================================
 """Sampling parameters for text generation."""

-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6

@@ -48,6 +48,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -71,6 +72,7 @@ class SamplingParams:
         self.json_schema = json_schema
         self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
+        self.custom_params = custom_params

         # Process some special cases
         if self.temperature < _SAMPLING_EPS: