sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff covers publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0

--- a/sglang/srt/layers/logits_processor.py
+++ b/sglang/srt/layers/logits_processor.py
@@ -1,3 +1,8 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List
+
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -5,7 +10,25 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+
+
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normlaized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
 
 
 class LogitsProcessor(nn.Module):
@@ -37,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -49,37 +73,34 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-
-
-
-                if extend_seq_lens_cpu[i] == 0:
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
                 k = input_metadata.top_logprobs_nums[i]
-                t = all_logprobs[pt : pt +
+                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
                 prefill_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
                 decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
-                pt +=
+                pt += extend_seq_len
+
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last
-
-
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
@@ -89,8 +110,14 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-
-
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -105,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
            if return_top_logprob:
                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -114,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-
-
-
-                    None,
-                    None,
-
-
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
-                # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -136,16 +163,18 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
-
-
-
-
-
-
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
 
 
-
+def test():
     all_logprobs = torch.tensor(
         # s s s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +202,7 @@ if __name__ == "__main__":
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
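
The change above replaces the old tuple return with a `LogitProcessorOutput` dataclass and rewrites the extend-mode top-k loop. For reference, here is a minimal standalone sketch of that loop, slicing a flattened `[#token, vocab_size]` logprob tensor per request; the helper name, shapes, and test data are illustrative, not sglang's API.

```python
import torch


def get_prefill_top_logprobs(all_logprobs, extend_seq_lens, top_logprobs_nums):
    """all_logprobs: [sum(extend_seq_lens), vocab_size] log-probabilities."""
    prefill_top, decode_top = [], []
    pt = 0
    for extend_seq_len, k in zip(extend_seq_lens, top_logprobs_nums):
        if extend_seq_len == 0:
            prefill_top.append([])
            decode_top.append([])
            continue
        t = all_logprobs[pt : pt + extend_seq_len].topk(k)
        vs, ps = t.values.tolist(), t.indices.tolist()
        # every position except the last belongs to the prompt (prefill) ...
        prefill_top.append([list(zip(vs[j], ps[j])) for j in range(len(vs) - 1)])
        # ... the last position predicts the first decoded token
        decode_top.append(list(zip(vs[-1], ps[-1])))
        pt += extend_seq_len
    return prefill_top, decode_top


if __name__ == "__main__":
    logprobs = torch.randn(7, 32).log_softmax(dim=-1)  # two requests: 3 + 4 tokens
    prefill_top, decode_top = get_prefill_top_logprobs(logprobs, [3, 4], [2, 2])
    print(len(prefill_top[0]), len(decode_top[0]))  # 2 prefill positions, k=2 pairs
```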

--- a/sglang/srt/layers/radix_attention.py
+++ b/sglang/srt/layers/radix_attention.py
@@ -1,14 +1,21 @@
+"""Radix attention."""
+
+import numpy as np
 import torch
 from torch import nn
 
+from sglang.global_config import global_config
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
-    def __init__(
+    def __init__(
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
@@ -16,16 +23,21 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
+
+        from sglang.srt.managers.controller.model_runner import global_server_args_dict
 
-        if global_server_args_dict.get("
+        if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer now accepts float logit_cap argument
+            self.logit_cap = logit_cap if logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -38,6 +50,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)
 
@@ -62,6 +75,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )
 
         return o
@@ -82,6 +96,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )
 
         return o
@@ -89,19 +104,38 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            logits_soft_cap=self.logit_cap,
         )
 
+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            from flashinfer.cascade import merge_state
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_soft_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
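
The new flashinfer prefill path above computes attention twice, once over the new ragged tokens and once over the cached prefix pages, and then combines the two partial results with `merge_state`. The sketch below is not flashinfer's implementation; it is a plain PyTorch illustration of the log-sum-exp merge that combining such partial results requires, with assumed names and shapes.

```python
import torch


def merge_attention_states(o1, s1, o2, s2):
    """o*: [tokens, heads, head_dim] partial attention outputs over disjoint key sets,
    s*: [tokens, heads] log-sum-exp of the corresponding raw attention scores."""
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # rescaled softmax denominators, for numerical stability
    w2 = torch.exp(s2 - s_max)
    o = (o1 * w1[..., None] + o2 * w2[..., None]) / (w1 + w2)[..., None]
    s = s_max + torch.log(w1 + w2)
    return o, s


if __name__ == "__main__":
    q = torch.randn(2, 4, 8)
    k = torch.randn(16, 4, 8)
    v = torch.randn(16, 4, 8)

    def attn(q, k, v):
        # non-causal attention plus the log-sum-exp of its scores
        scores = torch.einsum("qhd,khd->qhk", q, k)
        return torch.einsum("qhk,khd->qhd", torch.softmax(scores, -1), v), torch.logsumexp(scores, -1)

    # attention over two disjoint key halves, merged, equals attention over all keys
    o_full, _ = attn(q, k, v)
    o1, s1 = attn(q, k[:8], v[:8])
    o2, s2 = attn(q, k[8:], v[8:])
    o_merged, _ = merge_attention_states(o1, s1, o2, s2)
    print(torch.allclose(o_full, o_merged, atol=1e-5))  # True
```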

--- a/sglang/srt/layers/token_attention.py
+++ b/sglang/srt/layers/token_attention.py
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
@@ -16,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16
 
 
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
+
+        if logit_cap > 0:
+            att_value = logit_cap * tanh(att_value / logit_cap)
+
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
 
@@ -165,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -304,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
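
The `logit_cap` threading above applies tanh soft-capping to the raw attention scores inside the Triton decode kernel, so a score can never exceed `logit_cap` in magnitude while staying roughly linear near zero. A plain PyTorch equivalent of just that capping step, for illustration only:

```python
import torch


def soft_cap(att_scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    """Bound raw attention scores to (-logit_cap, logit_cap); logit_cap <= 0 disables the cap."""
    if logit_cap <= 0:
        return att_scores
    return logit_cap * torch.tanh(att_scores / logit_cap)


scores = torch.randn(4, 16) * 50.0         # unbounded q.k / sqrt(d) scores
print(soft_cap(scores, 30.0).abs().max())  # strictly below 30
```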

--- /dev/null
+++ b/sglang/srt/managers/controller/dp_worker.py
@@ -0,0 +1,113 @@
+"""A data parallel worker thread."""
+
+import asyncio
+import logging
+import queue
+import threading
+from typing import Callable, List
+
+import uvloop
+import zmq
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.managers.io_struct import BatchTokenIDOut
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+CHECKING_INTERVAL = 5
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+class DataParallelWorkerThread(threading.Thread):
+    def __init__(
+        self,
+        worker_id: int,
+        request_queue: queue.Queue,
+        detokenizer_port: int,
+        step_func: Callable,
+    ):
+        super(DataParallelWorkerThread, self).__init__()
+        self.worker_id = worker_id
+        self.request_queue = request_queue
+        self.liveness = True
+        self.request_dependency_delay = global_config.request_dependency_delay
+
+        context = zmq.asyncio.Context()
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
+
+        self.step = step_func
+
+    async def loop_for_forward(self):
+        while self.liveness:
+            requests = []
+            while not self.request_queue.empty():
+                requests.append(self.request_queue.get())
+
+            out_pyobjs: List[BatchTokenIDOut] = []
+            try:
+                out_pyobjs = await self.step(requests)
+            except Exception:
+                for r in requests:
+                    self.request_queue.put(r)
+                logger.error(
+                    f"Worker thread {self.worker_id}: "
+                    f"failed to get back from Model Server\n"
+                    f"{get_exception_traceback()}"
+                )
+                self.liveness = False
+                # Crash the whole server when there are any errors.
+                # TODO(lianmin): make this an option.
+                kill_parent_process()
+                return
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+            # async sleep for receiving the subsequent request and avoiding cache miss
+            if len(out_pyobjs) != 0:
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
+                if has_finished:
+                    await asyncio.sleep(self.request_dependency_delay)
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def monitoring(self):
+        while True:
+            await asyncio.sleep(CHECKING_INTERVAL)
+            # can plug in monitoring logic here
+
+    def run(self):
+        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.create_task(self.monitoring())
+        loop.run_until_complete(self.loop_for_forward())
+
+
+def start_data_parallel_worker(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    model_overide_args,
+    gpu_ids: List[int],
+    worker_id: int,
+):
+    model_tp_client = ModelTpClient(
+        gpu_ids,
+        server_args,
+        port_args.model_port_args[worker_id],
+        model_overide_args,
+    )
+    worker_thread = DataParallelWorkerThread(
+        worker_id=worker_id,
+        request_queue=queue.Queue(),
+        detokenizer_port=port_args.detokenizer_port,
+        step_func=model_tp_client.step,
+    )
+    worker_thread.start()
+    return worker_thread
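
The new dp_worker.py above runs one such thread per data-parallel replica: it drains a request queue, awaits one `ModelTpClient.step`, and pushes each `BatchTokenIDOut` to the detokenizer over ZeroMQ, crashing the parent process on failure. Below is a minimal standalone sketch of the same drain/step/dispatch loop, with a stubbed step function and a plain callable in place of the ZMQ socket; all names here are illustrative, not sglang's API.

```python
import asyncio
import queue
import threading
import time


class WorkerThread(threading.Thread):
    """Sketch of the dp_worker pattern: drain a queue, run a step, forward outputs."""

    def __init__(self, request_queue: queue.Queue, step, sink, poll_delay=0.001):
        super().__init__(daemon=True)
        self.request_queue = request_queue
        self.step = step            # async callable: list of requests -> list of outputs
        self.sink = sink            # called with each output (stands in for the ZMQ push)
        self.poll_delay = poll_delay
        self.liveness = True

    async def _loop(self):
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())
            for out in await self.step(requests):
                self.sink(out)
            await asyncio.sleep(self.poll_delay)

    def run(self):
        # each worker thread owns its own event loop, as dp_worker does
        asyncio.run(self._loop())


async def echo_step(requests):
    return [f"done:{r}" for r in requests]


if __name__ == "__main__":
    q = queue.Queue()
    worker = WorkerThread(q, echo_step, print)
    worker.start()
    q.put("req-1")
    q.put("req-2")
    time.sleep(0.1)  # daemon thread exits with the main thread
```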