sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -1,3 +1,8 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List, Union
+
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -8,6 +13,45 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normlaized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -15,14 +59,14 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, input_metadata: InputMetadata
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
     ):
         logprobs_cumsum = torch.cumsum(
             prefill_token_logprobs, dim=0, dtype=torch.float32
         )
 
-        start = input_metadata.extend_start_loc.clone()
-        end = start + input_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -31,16 +75,17 @@ class LogitsProcessor(nn.Module):
             + prefill_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
-            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
         )
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+        # TODO: vectorize the code below
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
@@ -49,14 +94,13 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
@@ -68,19 +112,26 @@ class LogitsProcessor(nn.Module):
 
         return prefill_top_logprobs, decode_top_logprobs
 
-    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
-            last_index = (
-                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
-                - 1
-            )
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
 
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
            last_hidden = hidden_states
        else:
+            last_index = (
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
+                - 1
+            )
            last_hidden = hidden_states[last_index]
 
        last_logits = torch.matmul(last_hidden, weight.T)
@@ -88,13 +139,24 @@ class LogitsProcessor(nn.Module):
            last_logits = tensor_model_parallel_all_gather(last_logits)
        last_logits = last_logits[:, : self.config.vocab_size]
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
        # Return only last_logits if logprob is not requested
-        if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+        if not logits_metadata.return_logprob:
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
        else:
            # When logprob is requested, compute the logits for all tokens.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                all_logits = last_logits
            else:
                all_logits = torch.matmul(hidden_states, weight.T)
@@ -106,25 +168,25 @@ class LogitsProcessor(nn.Module):
            del all_logits
            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            # Get the logprob of top-k tokens
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
            if return_top_logprob:
                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, input_metadata
+                    all_logprobs, logits_metadata
                )
            else:
                prefill_top_logprobs = decode_top_logprobs = None
 
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                )
            else:
-                # Compute the logprobs for the last token of each request.
                last_logprobs = all_logprobs[last_index]
 
                # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -135,14 +197,16 @@ class LogitsProcessor(nn.Module):
                ]
 
                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, input_metadata
+                    prefill_token_logprobs, logits_metadata
                )
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                )
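Two changes in logits_processor.py stand out. First, forward() now returns a LogitProcessorOutput dataclass, so callers read named fields such as next_token_logits and next_token_logprobs instead of unpacking a positional five-tuple. Second, when the model config defines final_logit_softcapping (a Gemma-2-style attribute, presumably relevant to the newly added gemma2.py), the final logits are squashed through a tanh soft cap. The divide / tanh / multiply sequence in the diff is equivalent to cap * tanh(logits / cap); below is a minimal standalone sketch of that transform, not code from the package, with an invented cap value:

    import torch

    def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # Same math as the diff's in-place divide / tanh / multiply, as one expression.
        return cap * torch.tanh(logits / cap)

    logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    print(softcap(logits, 30.0))  # large magnitudes are squashed smoothly into (-30, 30)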
sglang/srt/layers/radix_attention.py

@@ -1,52 +1,52 @@
-import torch
+"""Radix attention."""
+
 import numpy as np
+import torch
+from flashinfer.cascade import merge_state
 from torch import nn
 
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
+    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
+        self.scaling = scaling
        self.layer_id = layer_id
-        self.logit_cap = logit_cap
-
-        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
 
        from sglang.srt.managers.controller.model_runner import global_server_args_dict
 
-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
            self.prefill_forward = self.prefill_forward_flashinfer
            self.extend_forward = self.prefill_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer now accepts float logit_cap argument
+            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
        else:
            self.prefill_forward = self.prefill_forward_triton
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap if logit_cap is not None else 0
 
    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            self.logit_cap,
-        )
-        self.store_kv_cache(k, v, input_metadata)
-
-        return o
+        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
+        # See the extend_forward_xxx functions.
+        raise NotImplementedError()
 
    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)
@@ -67,7 +67,8 @@ class RadixAttention(nn.Module):
            input_metadata.extend_seq_lens,
            input_metadata.max_seq_len,
            input_metadata.max_extend_len,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
        )
 
        return o
@@ -88,27 +89,50 @@ class RadixAttention(nn.Module):
            input_metadata.max_seq_len,
            input_metadata.other_kv_index,
            input_metadata.total_num_tokens,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
        )
 
        return o
 
    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
-        o = input_metadata.prefill_wrapper.forward(
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
        )
 
+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                sm_scale=self.scaling,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        self.store_kv_cache(k, v, input_metadata)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
        return o.view(-1, self.tp_q_head_num * self.head_dim)
 
    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
        )
 
        return o.view(-1, self.tp_q_head_num * self.head_dim)
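With flashinfer enabled, extend/prefill attention is now split into two passes: a ragged, causal pass over the newly fed tokens (o1, s1) and a non-causal paged pass over the prefix KV cache (o2, s2), combined with flashinfer.cascade.merge_state using the per-head log-sum-exp (LSE) values. The sketch below shows the merge math in plain torch, assuming o1/o2 have shape [heads, head_dim] and s1/s2 have shape [heads]; the function name and shapes are illustrative, not flashinfer's API:

    import torch

    def merge_partial_attention(o1, s1, o2, s2):
        # o1/o2 are attention outputs over two disjoint key sets; s1/s2 are the
        # log-sum-exp normalizers of their softmaxes. Re-weight and renormalize.
        s_max = torch.maximum(s1, s2)
        w1 = torch.exp(s1 - s_max)
        w2 = torch.exp(s2 - s_max)
        o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
        s = s_max + torch.log(w1 + w2)  # combined LSE, usable for further merges
        return o, s

The diff also moves store_kv_cache after the attention passes and adds an optional torch.cuda.synchronize() once total_num_tokens reaches global_config.layer_sync_threshold.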
sglang/srt/layers/token_attention.py

@@ -176,6 +176,7 @@ def _token_att_m_fwd(
    B_Start_Loc,
    B_Seqlen,
    max_len_in_batch,
+    sm_scale,
    logit_cap,
 ):
    BLOCK = 32
@@ -183,7 +184,6 @@ def _token_att_m_fwd(
    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)
 
    batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -317,6 +317,7 @@ def token_attention_fwd(
    max_len_in_batch,
    other_kv_index,
    total_num_tokens,
+    sm_scale=None,
    logit_cap=-1,
    att_m=None,
 ):
@@ -324,6 +325,7 @@ def token_attention_fwd(
        att_m = torch.empty(
            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
        )
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
 
    _token_att_m_fwd(
        q,
@@ -334,6 +336,7 @@ def token_attention_fwd(
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
+        sm_scale,
        logit_cap,
    )
    _token_softmax_reducev_fwd(
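The Triton decode kernel previously hardcoded its softmax scale to 1/sqrt(head_dim) (and RadixAttention asserted that the model's scaling matched); it now takes an explicit sm_scale, falling back to that default when none is given, so models with a non-standard attention scale still work. Roughly, sm_scale multiplies the query-key dot products before the softmax; a plain-torch sketch for a single head, with invented shapes and values:

    import torch

    head_dim = 64
    q = torch.randn(1, head_dim)           # one decoding query
    k = torch.randn(10, head_dim)          # ten cached keys
    sm_scale = 1.0 / head_dim ** 0.5       # the default used when sm_scale is None

    scores = (q @ k.T) * sm_scale          # scaled dot products; logit_cap, if set, caps these scores
    probs = torch.softmax(scores, dim=-1)  # attention weights over the cached tokens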
sglang/srt/managers/controller/dp_worker.py

@@ -1,9 +1,10 @@
 """A data parallel worker thread."""
+
 import asyncio
 import logging
 import queue
 import threading
-from typing import List, Callable
+from typing import Callable, List
 
 import uvloop
 import zmq
@@ -69,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
 
        # async sleep for receiving the subsequent request and avoiding cache miss
        if len(out_pyobjs) != 0:
-            has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+            has_finished = any(
+                [obj.finished_reason is not None for obj in out_pyobjs]
+            )
            if has_finished:
                await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)
@@ -107,4 +110,4 @@ def start_data_parallel_worker(
        step_func=model_tp_client.step,
    )
    worker_thread.start()
-    return worker_thread
+    return worker_thread