sglang 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +26 -0
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +34 -16
- sglang/global_config.py +1 -0
- sglang/lang/chat_template.py +41 -6
- sglang/lang/interpreter.py +5 -1
- sglang/lang/ir.py +61 -25
- sglang/srt/constrained/__init__.py +3 -2
- sglang/srt/hf_transformers_utils.py +7 -3
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +181 -167
- sglang/srt/layers/logits_processor.py +55 -19
- sglang/srt/layers/radix_attention.py +24 -27
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/infer_batch.py +2 -2
- sglang/srt/managers/controller/manager_single.py +1 -1
- sglang/srt/managers/controller/model_runner.py +27 -15
- sglang/srt/managers/controller/tp_worker.py +31 -14
- sglang/srt/managers/detokenizer_manager.py +4 -2
- sglang/srt/managers/io_struct.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +6 -0
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/llama2.py +3 -3
- sglang/srt/models/llama_classification.py +10 -7
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/openai_api_adapter.py +2 -2
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +17 -8
- sglang/srt/server_args.py +14 -16
- sglang/srt/utils.py +68 -35
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py CHANGED
@@ -2,10 +2,10 @@

 import numpy as np
 import torch
+from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@@ -13,18 +13,22 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata

 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads, head_dim, scaling, num_kv_heads, layer_id,
-        logit_cap=-1,
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id

-        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
-
         from sglang.srt.managers.controller.model_runner import global_server_args_dict

         if not global_server_args_dict.get("disable_flashinfer", False):
@@ -32,29 +36,17 @@ class RadixAttention(nn.Module):
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
             # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap > 0 else 0
+            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap
+            self.logit_cap = logit_cap if logit_cap is not None else 0

     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            self.logit_cap,
-        )
-        self.store_kv_cache(k, v, input_metadata)
-
-        return o
+        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
+        # See the extend_forward_xxx functions.
+        raise NotImplementedError()

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -75,7 +67,8 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )

         return o
@@ -96,18 +89,19 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )

         return o

     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
             v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )

@@ -118,12 +112,14 @@ class RadixAttention(nn.Module):
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
             causal=False,
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )

-        from flashinfer.cascade import merge_state
         o, _ = merge_state(o1, s1, o2, s2)

+        self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()

@@ -135,6 +131,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )

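The net effect of the RadixAttention changes: the softmax scale (`scaling`) is now stored on the layer and passed to every backend kernel as `sm_scale` instead of being asserted to equal `1/sqrt(head_dim)` (presumably to support models such as the newly added Gemma 2, whose attention scale and logit soft cap differ from the default), and `logit_cap` is normalized to `0` when unset. A minimal sketch of that normalization, mirroring only what the diff shows (the `use_flashinfer` flag stands in for `not global_server_args_dict.get("disable_flashinfer", False)`; the function name is illustrative):

```python
# Illustrative sketch of the logit_cap handling added in 0.1.19.
def normalize_logit_cap(logit_cap, use_flashinfer: bool) -> float:
    if use_flashinfer:
        # flashinfer accepts a float logit_cap; 0 disables capping.
        return logit_cap if logit_cap is not None and logit_cap > 0 else 0
    return logit_cap if logit_cap is not None else 0

assert normalize_logit_cap(None, use_flashinfer=True) == 0
assert normalize_logit_cap(-1, use_flashinfer=True) == 0     # default sentinel
assert normalize_logit_cap(50.0, use_flashinfer=True) == 50.0  # a model-provided cap is kept
```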
sglang/srt/layers/token_attention.py CHANGED
@@ -176,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
     logit_cap,
 ):
     BLOCK = 32
@@ -183,7 +184,6 @@ def _token_att_m_fwd(
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -317,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    sm_scale=None,
     logit_cap=-1,
     att_m=None,
 ):
@@ -324,6 +325,7 @@ def token_attention_fwd(
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale

     _token_att_m_fwd(
         q,
@@ -334,6 +336,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        sm_scale,
         logit_cap,
     )
     _token_softmax_reducev_fwd(
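The Triton decode kernel now takes the softmax scale from the caller, keeping the old `1/sqrt(head_dim)` value only as a fallback when `sm_scale` is not given. A quick numeric check of that fallback (the head dimension here is just an example):

```python
# Fallback softmax scale used when the caller does not pass sm_scale.
head_dim = 128                       # example head dimension
sm_scale = 1.0 / (head_dim ** 0.5)   # = 1 / sqrt(128)
print(round(sm_scale, 6))            # 0.088388
```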
sglang/srt/managers/controller/infer_batch.py CHANGED
@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union

 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class BaseFinishReason:


 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched

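`FINISH_MATCHED_TOKEN.matched` can now hold either a single stop-token id or a list of them. A hypothetical helper showing how such a value might be checked uniformly (this is an illustration, not code from the package):

```python
from typing import List, Union

def is_stop_token(last_token_id: int, matched: Union[int, List[int]]) -> bool:
    """Treat a single stop-token id and a list of ids the same way."""
    if isinstance(matched, int):
        return last_token_id == matched
    return last_token_id in matched

assert is_stop_token(2, 2)
assert is_stop_token(2, [1, 2, 32000])
```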
sglang/srt/managers/controller/model_runner.py CHANGED
@@ -127,7 +127,7 @@ class InputMetadata:
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
-                1
+                1,
             )
         else:
             self.flashinfer_decode_wrapper.end_forward()
@@ -140,7 +140,7 @@ class InputMetadata:
                 head_dim,
                 1,
                 pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
             )

     def init_extend_args(self):
@@ -228,7 +228,7 @@ class InputMetadata:
         ret.init_flashinfer_args(
             model_runner.model_config.num_attention_heads // tp_size,
             model_runner.model_config.get_num_kv_heads(tp_size),
-            model_runner.model_config.head_dim
+            model_runner.model_config.head_dim,
         )

         return ret
@@ -259,7 +259,10 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-        monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
         if server_args.nccl_init_addr:
             nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
         else:
@@ -269,7 +272,7 @@ class ModelRunner:
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -323,7 +326,7 @@ class ModelRunner:
             device_config=device_config,
             load_config=load_config,
             lora_config=None,
-            vision_language_config=None,
+            multimodal_config=None,
             parallel_config=None,
             scheduler_config=None,
             cache_config=None,
@@ -341,7 +344,13 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
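The `cell_size` expression (bytes of KV cache per token) was only reformatted here, but it is worth spelling out: each token stores K and V (the factor 2) for every layer and every local KV head. A worked example with made-up model dimensions:

```python
import torch

# Hypothetical model dimensions, just to illustrate the formula.
head_num = 8            # KV heads on this tensor-parallel rank
head_dim = 128
num_hidden_layers = 32
dtype = torch.float16   # 2 bytes per element

cell_size = (
    head_num
    * head_dim
    * num_hidden_layers
    * 2                                  # K and V
    * torch._utils._element_size(dtype)
)
print(cell_size)  # 131072 bytes, i.e. 128 KiB of KV cache per token
```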
@@ -384,33 +393,36 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
-                BatchPrefillWithRaggedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
             )
             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

             if not _grouped_size_compiled_for_decode_kernels(
                 self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size)):
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
                 use_tensor_cores = True
             else:
                 use_tensor_cores = False

             workspace_buffers = torch.empty(
-                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                workspace_buffers[0], "NHD"
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffers[1], "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
             self.flashinfer_decode_wrapper = None

     @torch.inference_mode()
sglang/srt/managers/controller/tp_worker.py CHANGED
@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     start_rpyc_service_process,
-    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -368,9 +368,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and (req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0)
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
@@ -452,7 +454,9 @@ class ModelTpServer:
                 next_token_ids,
             ].tolist()
             output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-            output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
+            output.normalized_prompt_logprobs = (
+                output.normalized_prompt_logprobs.tolist()
+            )

             next_token_ids = next_token_ids.tolist()
         else:
@@ -582,7 +586,9 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
             if req.top_logprobs_num > 0:
                 req.decode_top_logprobs.append(output.decode_top_logprobs[i])

@@ -759,16 +765,27 @@ class ModelTpClient:
         with ThreadPoolExecutor(self.tp_size) as executor:
             # Launch model processes
             if server_args.nnodes == 1:
-                self.procs = list(executor.map(
-                    lambda args: start_rpyc_service_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                ))
+                self.procs = list(
+                    executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [
+                            (ModelTpService, p)
+                            for p in model_port_args.model_tp_ports
+                        ],
+                    )
+                )
                 addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
             else:
-                addrs = [(ip, port) for ip, port in zip(
-                    model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
-
-            self.model_services = list(executor.map(lambda args: connect_rpyc_service(*args), addrs))
+                addrs = [
+                    (ip, port)
+                    for ip, port in zip(
+                        model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                    )
+                ]
+
+            self.model_services = list(
+                executor.map(lambda args: connect_rpyc_service(*args), addrs)
+            )

             # Init model
             def init_model(i):
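The scheduler condition in the first ModelTpServer hunk only changed formatting, but it encodes the admission rule for building the next prefill batch: a request must fit in the available KV cache, and either keep the batch under `max_prefill_tokens` or be the first request admitted. A standalone restatement of that predicate (function and variable names here are illustrative, not the package's API):

```python
def can_add_to_prefill_batch(
    extend_input_len: int,
    max_new_tokens: int,
    new_batch_total_tokens: int,
    new_batch_input_tokens: int,
    available_size: int,
    max_prefill_tokens: int,
    num_already_admitted: int,
) -> bool:
    fits_in_cache = (
        extend_input_len + max_new_tokens + new_batch_total_tokens < available_size
    )
    fits_in_prefill_budget = (
        extend_input_len + new_batch_input_tokens <= max_prefill_tokens
        or num_already_admitted == 0  # always admit at least one request
    )
    return fits_in_cache and fits_in_prefill_budget
```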
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -11,7 +11,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -57,6 +57,8 @@ class DetokenizerManager:
         output_strs = []
         for i in range(len(recv_obj.rids)):
             new_text = read_texts[i][len(surr_texts[i]) :]
+            if recv_obj.finished_reason[i] is None:
+                new_text = find_printable_text(new_text)
             output_strs.append(recv_obj.decoded_texts[i] + new_text)

             if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
@@ -67,7 +69,7 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
-                    output_str=output_strs,
+                    output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
                 )
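For still-running requests the detokenizer now passes each newly decoded fragment through `find_printable_text`, so partially decoded characters are not streamed to the client. A rough sketch of the idea, assuming the helper simply withholds a trailing incomplete character (this is an illustration, not the function's actual implementation):

```python
def find_printable_text_sketch(text: str) -> str:
    """Return the prefix of `text` that is safe to stream now.

    Incremental detokenization can end in U+FFFD (the replacement character)
    when the last token only covers part of a multi-byte character; hold that
    tail back until the next decoding step completes it.
    """
    while text.endswith("\ufffd"):
        text = text[:-1]
    return text

print(find_printable_text_sketch("hello"))       # hello
print(find_printable_text_sketch("caf\ufffd"))   # caf  (incomplete char withheld)
```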
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -316,7 +316,7 @@ class TokenizerManager:

             recv_obj.meta_info[i]["id"] = rid
             out_dict = {
-                "text": recv_obj.output_str[i],
+                "text": recv_obj.output_strs[i],
                 "meta_info": recv_obj.meta_info[i],
             }
             state.out_list.append(out_dict)
@@ -333,17 +333,18 @@ class TokenizerManager:
             ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
-            if top_logprobs_num > 0:
-                ret["meta_info"][
-                    "prefill_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-                )
-                ret["meta_info"][
-                    "decode_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-                )
+
+            if top_logprobs_num > 0:
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
         return ret

     def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
@@ -383,7 +384,7 @@ def get_pixel_values(
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
-        if image_size:
+        if image_size is not None:
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):
sglang/srt/model_config.py CHANGED
@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
         # `num_attention_heads` (among others). Assert here to fail early