sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py +65 -32

@@ -1,3 +1,8 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List
+
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -5,7 +10,25 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+
+
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normlaized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List


 class LogitsProcessor(nn.Module):
@@ -37,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs

     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -49,37 +73,34 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
-            for i in range(len(extend_seq_lens_cpu)):
-                if extend_seq_lens_cpu[i] == 0:
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
                 k = input_metadata.top_logprobs_nums[i]
-                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
                 prefill_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
                 decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
-                pt += extend_seq_lens_cpu[i]
+                pt += extend_seq_len
+
             return prefill_top_logprobs, decode_top_logprobs

     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)
@@ -89,8 +110,14 @@ class LogitsProcessor(nn.Module):

         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -105,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -114,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None

             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
-                # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -136,16 +163,18 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )


-if __name__ == "__main__":
+def test():
     all_logprobs = torch.tensor(
         # s s s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +202,7 @@ if __name__ == "__main__":
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
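For orientation, here is a minimal, hypothetical sketch of what consuming the new dataclass-style return value looks like compared with the old positional tuple. The field names mirror the LogitProcessorOutput definition in the diff above; the Optional defaults, the sample_greedy helper, and the tensor shapes are illustrative assumptions, not sglang code.

```python
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class LogitProcessorOutput:
    # Mirrors the fields added in the diff above; Optional defaults are an
    # assumption for this sketch so the object can be built without logprobs.
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None
    normalized_prompt_logprobs: Optional[torch.Tensor] = None
    prefill_token_logprobs: Optional[torch.Tensor] = None
    prefill_top_logprobs: Optional[List] = None
    decode_top_logprobs: Optional[List] = None


def sample_greedy(output: LogitProcessorOutput) -> torch.Tensor:
    # A caller that only needs sampling reads one named field instead of
    # unpacking a 6-element tuple by position.
    return torch.argmax(output.next_token_logits, dim=-1)


if __name__ == "__main__":
    out = LogitProcessorOutput(next_token_logits=torch.randn(2, 8))
    print(sample_greedy(out))  # tensor of shape [2]
```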
sglang/srt/layers/radix_attention.py +41 -7

@@ -1,14 +1,21 @@
+"""Radix attention."""
+
+import numpy as np
 import torch
 from torch import nn

+from sglang.global_config import global_config
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
@@ -16,16 +23,21 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id

-        from sglang.srt.managers.router.model_runner import global_server_args_dict
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
+
+        from sglang.srt.managers.controller.model_runner import global_server_args_dict

-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer now accepts float logit_cap argument
+            self.logit_cap = logit_cap if logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap

     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -38,6 +50,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)

@@ -62,6 +75,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )

         return o
@@ -82,6 +96,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )

         return o
@@ -89,19 +104,38 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.prefill_wrapper.forward(
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            logits_soft_cap=self.logit_cap,
         )

+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            from flashinfer.cascade import merge_state
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_soft_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
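The logit_cap plumbing above applies soft capping to attention scores: when logit_cap > 0, the Triton path computes logit_cap * tanh(score / logit_cap), and the flashinfer path passes the same value as logits_soft_cap. The sketch below reproduces that formula in plain PyTorch purely for illustration; soft_cap is a hypothetical helper, not part of sglang, and the real computation happens inside the kernels.

```python
import torch


def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # logit_cap <= 0 means capping is disabled, matching the -1 default above.
    if logit_cap <= 0:
        return scores
    # Smoothly squashes scores into (-logit_cap, logit_cap); near zero the
    # transform is approximately the identity.
    return logit_cap * torch.tanh(scores / logit_cap)


scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores, 30.0))  # large magnitudes saturate near +/-30
```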
sglang/srt/layers/token_attention.py +16 -1

@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.router.model_runner import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
@@ -16,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
+
+        if logit_cap > 0:
+            att_value = logit_cap * tanh(att_value / logit_cap)
+
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

@@ -165,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -304,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
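The Triton helper above expresses tanh as 2 * sigmoid(2x) - 1, presumably so the kernel only relies on tl.sigmoid. A quick numerical check of that identity in PyTorch (a verification sketch, not sglang code):

```python
import torch

x = torch.linspace(-4.0, 4.0, steps=9)
# tanh(x) == 2 * sigmoid(2x) - 1 for all real x.
assert torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1, atol=1e-6)
print("identity holds on", x.tolist())
```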
sglang/srt/managers/controller/dp_worker.py +113 -0

@@ -0,0 +1,113 @@
+"""A data parallel worker thread."""
+
+import asyncio
+import logging
+import queue
+import threading
+from typing import Callable, List
+
+import uvloop
+import zmq
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.managers.io_struct import BatchTokenIDOut
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+CHECKING_INTERVAL = 5
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+class DataParallelWorkerThread(threading.Thread):
+    def __init__(
+        self,
+        worker_id: int,
+        request_queue: queue.Queue,
+        detokenizer_port: int,
+        step_func: Callable,
+    ):
+        super(DataParallelWorkerThread, self).__init__()
+        self.worker_id = worker_id
+        self.request_queue = request_queue
+        self.liveness = True
+        self.request_dependency_delay = global_config.request_dependency_delay
+
+        context = zmq.asyncio.Context()
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
+
+        self.step = step_func
+
+    async def loop_for_forward(self):
+        while self.liveness:
+            requests = []
+            while not self.request_queue.empty():
+                requests.append(self.request_queue.get())
+
+            out_pyobjs: List[BatchTokenIDOut] = []
+            try:
+                out_pyobjs = await self.step(requests)
+            except Exception:
+                for r in requests:
+                    self.request_queue.put(r)
+                logger.error(
+                    f"Worker thread {self.worker_id}: "
+                    f"failed to get back from Model Server\n"
+                    f"{get_exception_traceback()}"
+                )
+                self.liveness = False
+                # Crash the whole server when there are any errors.
+                # TODO(lianmin): make this an option.
+                kill_parent_process()
+                return
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+            # async sleep for receiving the subsequent request and avoiding cache miss
+            if len(out_pyobjs) != 0:
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
+                if has_finished:
+                    await asyncio.sleep(self.request_dependency_delay)
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def monitoring(self):
+        while True:
+            await asyncio.sleep(CHECKING_INTERVAL)
+            # can plug in monitoring logic here
+
+    def run(self):
+        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.create_task(self.monitoring())
+        loop.run_until_complete(self.loop_for_forward())
+
+
+def start_data_parallel_worker(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    model_overide_args,
+    gpu_ids: List[int],
+    worker_id: int,
+):
+    model_tp_client = ModelTpClient(
+        gpu_ids,
+        server_args,
+        port_args.model_port_args[worker_id],
+        model_overide_args,
+    )
+    worker_thread = DataParallelWorkerThread(
+        worker_id=worker_id,
+        request_queue=queue.Queue(),
+        detokenizer_port=port_args.detokenizer_port,
+        step_func=model_tp_client.step,
+    )
+    worker_thread.start()
+    return worker_thread
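The new dp_worker.py follows a thread-per-replica pattern: each worker thread owns its own asyncio event loop, drains a thread-safe queue of requests, forwards the batch, and pushes outputs to the detokenizer over ZeroMQ. The sketch below is a stripped-down, self-contained illustration of that control flow only; WorkerThread and fake_step are hypothetical stand-ins for DataParallelWorkerThread and ModelTpClient.step, and printing replaces the ZMQ socket.

```python
import asyncio
import queue
import threading
import time


class WorkerThread(threading.Thread):
    """Minimal stand-in for DataParallelWorkerThread's control flow."""

    def __init__(self, request_queue: queue.Queue):
        super().__init__(daemon=True)
        self.request_queue = request_queue
        self.liveness = True

    async def fake_step(self, requests):
        # Placeholder for ModelTpClient.step: pretend to run one forward pass.
        return [f"output for {r}" for r in requests]

    async def loop_for_forward(self):
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())

            for out in await self.fake_step(requests):
                print(out)  # the real worker sends these to the detokenizer via ZMQ

            # Yield briefly, like wait_for_new_request_delay in the real worker.
            await asyncio.sleep(0.01)

    def run(self):
        # Each worker thread runs its own event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.loop_for_forward())


if __name__ == "__main__":
    q = queue.Queue()
    worker = WorkerThread(q)
    worker.start()
    q.put("request-0")
    q.put("request-1")
    time.sleep(0.1)
    worker.liveness = False
    worker.join()
```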