sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -1,3 +1,8 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List, Union
+
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -8,6 +13,45 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
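The two dataclasses above decouple logits processing from the scheduler's full `InputMetadata`. A minimal usage sketch (the tensor values and the `EXTEND` mode are illustrative, assuming sglang 0.1.19 is installed; this is not code from the diff):

```python
import torch

from sglang.srt.layers.logits_processor import LogitsMetadata
from sglang.srt.managers.controller.model_runner import ForwardMode

# Hypothetical batch: two prompts of lengths 3 and 5, flattened together.
metadata = LogitsMetadata(
    forward_mode=ForwardMode.EXTEND,
    extend_seq_lens=torch.tensor([3, 5]),
    extend_start_loc=torch.tensor([0, 3]),
    return_logprob=True,
    top_logprobs_nums=[5, 5],
)
# LogitsProcessor.forward also accepts a raw InputMetadata and converts it
# itself via LogitsMetadata.from_input_metadata, so both call styles work.
```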
@@ -15,14 +59,14 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, input_metadata: InputMetadata
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
     ):
         logprobs_cumsum = torch.cumsum(
             prefill_token_logprobs, dim=0, dtype=torch.float32
         )
 
-        start = input_metadata.extend_start_loc.clone()
-        end = start + input_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -31,16 +75,17 @@ class LogitsProcessor(nn.Module):
             + prefill_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
-            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
         )
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+        # TODO: vectorize the code below
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
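`_get_normalized_prompt_logprobs` avoids a per-sequence Python loop by turning each sequence's sum of token logprobs into a difference of one shared cumulative sum. The identity it relies on, as a self-contained sketch (values here are made up):

```python
import torch

x = torch.tensor([0.1, 0.2, 0.3, 0.4])
c = torch.cumsum(x, dim=0)

start, end = 1, 3
# Sum of x[start : end + 1], computed from the cumsum alone:
# c[end] - c[start] drops x[start], so it is added back explicitly.
range_sum = c[end] - c[start] + x[start]
assert torch.isclose(range_sum, x[start : end + 1].sum())
```

The `- 2` in `end = start + extend_seq_lens - 2` and the `(extend_seq_lens - 1)` normalizer both reflect that the first prompt token has no logprob of its own, so each sequence contributes one fewer scored token than its length.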
@@ -49,14 +94,13 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
@@ -68,19 +112,26 @@ class LogitsProcessor(nn.Module):
 
         return prefill_top_logprobs, decode_top_logprobs
 
-    def forward(
-
-
-
-
-
-
-
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
 
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
             last_hidden = hidden_states
         else:
+            last_index = (
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
+                - 1
+            )
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
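In extend (prefill) mode, `last_index` picks out the final position of each sequence in the flattened batch, since only those positions are needed to predict each sequence's next token. A quick standalone illustration of the cumsum indexing:

```python
import torch

extend_seq_lens = torch.tensor([3, 5, 2])
# hidden_states is flattened to [3 + 5 + 2 = 10, hidden_size]; each
# sequence's last token sits at its cumulative end offset minus one.
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index.tolist())  # [2, 7, 9]
```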
@@ -88,13 +139,24 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size]
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
         # Return only last_logits if logprob is not requested
-        if not input_metadata.return_logprob:
-
-
+        if not logits_metadata.return_logprob:
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 all_logits = last_logits
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
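The new `final_logit_softcapping` branch is the Gemma-2-style transform `logits ← cap · tanh(logits / cap)`, which smoothly bounds every logit to (−cap, cap). Written as a pure function equivalent to the three in-place steps in the diff:

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # cap * tanh(logits / cap): near-identity when |logits| << cap,
    # saturating toward +/-cap for large magnitudes.
    return cap * torch.tanh(logits / cap)
```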
@@ -106,25 +168,25 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            # Get the logprob of top-k tokens
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, input_metadata
+                    all_logprobs, logits_metadata
                 )
             else:
                 prefill_top_logprobs = decode_top_logprobs = None
 
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-
-
-
-                    None,
-                    None,
-
-
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
-                # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -135,14 +197,16 @@ class LogitsProcessor(nn.Module):
                 ]
 
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, input_metadata
+                    prefill_token_logprobs, logits_metadata
                 )
-
-
-
-
-
-
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
 
 
sglang/srt/layers/radix_attention.py

@@ -1,52 +1,52 @@
-import torch
+"""Radix attention."""
+
 import numpy as np
+import torch
+from flashinfer.cascade import merge_state
 from torch import nn
 
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
-
-        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
 
         from sglang.srt.managers.controller.model_runner import global_server_args_dict
 
-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer now accepts float logit_cap argument
+            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap if logit_cap is not None else 0
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            self.logit_cap,
-        )
-        self.store_kv_cache(k, v, input_metadata)
-
-        return o
+        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
+        # See the extend_forward_xxx functions.
+        raise NotImplementedError()
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
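`logit_cap` configures soft-capping of the pre-softmax attention scores (used by models such as Grok-1 and Gemma-2); the flashinfer path forwards it as `logits_soft_cap`, and a value of 0 disables it. Conceptually, a sketch of the math rather than the fused kernel:

```python
import torch

def capped_attention_scores(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # With logit_cap > 0, scores are squashed into (-logit_cap, logit_cap)
    # before the softmax; with logit_cap == 0, they pass through unchanged.
    if logit_cap > 0:
        return logit_cap * torch.tanh(scores / logit_cap)
    return scores
```

Note the removed assertion `np.allclose(scaling, 1.0 / (head_dim**0.5))`: the layer now stores an arbitrary `self.scaling` and passes it through to the kernels, rather than requiring the standard 1/sqrt(head_dim) scale.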
@@ -67,7 +67,8 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
@@ -88,27 +89,50 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
 
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
-        o = input_metadata.prefill_wrapper.forward(
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
         )
 
+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                sm_scale=self.scaling,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        self.store_kv_cache(k, v, input_metadata)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
        )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
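The rewritten `prefill_forward_flashinfer` computes attention in two pieces: new tokens against themselves (ragged, causal) and new tokens against the cached prefix (paged, non-causal), then combines them with `merge_state` using each piece's log-sum-exp. A hand-rolled sketch of the math that merge implements; the real flashinfer kernel is fused, and the shapes here are assumed to be `[tokens, heads, head_dim]` for outputs and `[tokens, heads]` for LSEs:

```python
import torch

def merge_attention_states(o1, s1, o2, s2):
    # Combine two partial softmax attentions over disjoint key sets.
    # s1/s2 are the log-sum-exp of each part's raw attention scores;
    # weighting by exp(s - s_max) reproduces attention over the union.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)
    w2 = torch.exp(s2 - s_max)
    o = (w1.unsqueeze(-1) * o1 + w2.unsqueeze(-1) * o2) / (w1 + w2).unsqueeze(-1)
    s = s_max + torch.log(w1 + w2)
    return o, s
```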
sglang/srt/layers/token_attention.py

@@ -176,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
     logit_cap,
 ):
     BLOCK = 32
@@ -183,7 +184,6 @@ def _token_att_m_fwd(
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -317,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    sm_scale=None,
     logit_cap=-1,
     att_m=None,
 ):
@@ -324,6 +325,7 @@ def token_attention_fwd(
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
 
     _token_att_m_fwd(
         q,
@@ -334,6 +336,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        sm_scale,
        logit_cap,
     )
     _token_softmax_reducev_fwd(
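`token_attention_fwd` now takes the softmax scale as a parameter instead of hard-coding it inside `_token_att_m_fwd`, mirroring the `RadixAttention` change above (which passes `sm_scale=self.scaling`), presumably so models with non-standard attention scales can supply their own. The fallback in isolation (`resolve_sm_scale` is an illustrative name, not from the diff):

```python
import math
from typing import Optional

def resolve_sm_scale(head_dim: int, sm_scale: Optional[float] = None) -> float:
    # Matches the new default: honor an explicit scale when given,
    # otherwise fall back to the usual 1/sqrt(head_dim).
    return 1.0 / math.sqrt(head_dim) if sm_scale is None else sm_scale
```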
sglang/srt/managers/controller/dp_worker.py

@@ -1,9 +1,10 @@
 """A data parallel worker thread."""
+
 import asyncio
 import logging
 import queue
 import threading
-from typing import
+from typing import Callable, List
 
 import uvloop
 import zmq
@@ -69,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
 
         # async sleep for receiving the subsequent request and avoiding cache miss
         if len(out_pyobjs) != 0:
-            has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+            has_finished = any(
+                [obj.finished_reason is not None for obj in out_pyobjs]
+            )
             if has_finished:
                 await asyncio.sleep(self.request_dependency_delay)
             await asyncio.sleep(global_config.wait_for_new_request_delay)
@@ -107,4 +110,4 @@ def start_data_parallel_worker(
         step_func=model_tp_client.step,
     )
     worker_thread.start()
-    return worker_thread
\ No newline at end of file
+    return worker_thread