sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{router → controller}/infer_batch.py
@@ -1,3 +1,6 @@
+ """Meta data for requests and batches"""
+
+ import warnings
  from dataclasses import dataclass
  from enum import IntEnum, auto
  from typing import List
@@ -5,9 +8,13 @@ from typing import List
  import numpy as np
  import torch

- from sglang.srt.managers.router.radix_cache import RadixCache
+ from sglang.srt.constrained import RegexGuide
+ from sglang.srt.constrained.jump_forward import JumpForwardMap
+ from sglang.srt.managers.controller.radix_cache import RadixCache
  from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

+ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+

  class ForwardMode(IntEnum):
      PREFILL = auto()
@@ -15,33 +22,62 @@ class ForwardMode(IntEnum):
      DECODE = auto()


- class FinishReason(IntEnum):
-     EOS_TOKEN = auto()
-     LENGTH = auto()
-     STOP_STR = auto()
+ class BaseFinishReason:
+     def __init__(self, is_error: bool = False):
+         self.is_error = is_error
+
+     def __str__(self):
+         raise NotImplementedError("Subclasses must implement this method")
+
+
+ class FINISH_MATCHED_TOKEN(BaseFinishReason):
+     def __init__(self, matched: int | List[int]):
+         super().__init__()
+         self.matched = matched
+
+     def __str__(self) -> str:
+         return f"FINISH_MATCHED_TOKEN: {self.matched}"
+
+
+ class FINISH_LENGTH(BaseFinishReason):
+     def __init__(self, length: int):
+         super().__init__()
+         self.length = length
+
+     def __str__(self) -> str:
+         return f"FINISH_LENGTH: {self.length}"
+
+
+ class FINISH_MATCHED_STR(BaseFinishReason):
+     def __init__(self, matched: str):
+         super().__init__()
+         self.matched = matched
+
+     def __str__(self) -> str:
+         return f"FINISH_MATCHED_STR: {self.matched}"

-     @staticmethod
-     def to_str(reason):
-         if reason == FinishReason.EOS_TOKEN:
-             return None
-         elif reason == FinishReason.LENGTH:
-             return "length"
-         elif reason == FinishReason.STOP_STR:
-             return "stop"
-         else:
-             return None
+
+ class FINISH_ABORT(BaseFinishReason):
+     def __init__(self):
+         super().__init__(is_error=True)
+
+     def __str__(self) -> str:
+         return "FINISH_ABORT"


  class Req:
-     def __init__(self, rid, input_text, input_ids):
+     def __init__(self, rid, origin_input_text, origin_input_ids):
          self.rid = rid
-         self.input_text = input_text
-         self.input_ids = input_ids
-         self.output_ids = []
+         self.origin_input_text = origin_input_text
+         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
+         self.origin_input_ids = origin_input_ids
+         self.output_ids = []  # Each decode stage's output ids
+         self.input_ids = None  # input_ids = origin_input_ids + output_ids

-         # Since jump forward may retokenize the prompt with partial outputs,
-         # we maintain the original prompt length to report the correct usage.
-         self.prompt_tokens = len(input_ids)
+         # For incremental decode
+         self.decoded_text = ""
+         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+         self.read_offset = None

          # The number of decoded tokens for token usage report. Note that
          # this does not include the jump forward tokens.
@@ -57,12 +93,12 @@ class Req:
          self.sampling_params = None
          self.stream = False

-         # Check finish
          self.tokenizer = None
-         self.finished = False
-         self.finish_reason = None
-         self.hit_stop_str = None

+         # Check finish
+         self.finished_reason = None
+
+         # Prefix info
          self.extend_input_len = 0
          self.prefix_indices = []
          self.last_node = None
@@ -73,80 +109,81 @@ class Req:
          self.top_logprobs_num = 0
          self.normalized_prompt_logprob = None
          self.prefill_token_logprobs = None
-         self.decode_token_logprobs = None
          self.prefill_top_logprobs = None
-         self.decode_top_logprobs = None
+         self.decode_token_logprobs = []
+         self.decode_top_logprobs = []
+         # The tokens is prefilled but need to be considered as decode tokens
+         # and should be updated for the decode logprobs
+         self.last_update_decode_tokens = 0

          # Constrained decoding
-         self.regex_fsm = None
-         self.regex_fsm_state = 0
-         self.jump_forward_map = None
-         self.output_and_jump_forward_str = ""
+         self.regex_fsm: RegexGuide = None
+         self.regex_fsm_state: int = 0
+         self.jump_forward_map: JumpForwardMap = None
+
+     # whether request reached finished condition
+     def finished(self) -> bool:
+         return self.finished_reason is not None
+
+     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+     def init_detokenize_incrementally(self):
+         first_iter = self.surr_offset is None or self.read_offset is None
+
+         if first_iter:
+             self.read_offset = len(self.origin_input_ids_unpadded)
+             self.surr_offset = max(
+                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+             )

-     def max_new_tokens(self):
-         return self.sampling_params.max_new_tokens
+         all_ids = self.origin_input_ids_unpadded + self.output_ids
+         surr_ids = all_ids[self.surr_offset : self.read_offset]
+         read_ids = all_ids[self.surr_offset :]

-     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-         old_output_str = self.tokenizer.decode(self.output_ids)
-         # FIXME: This logic does not really solve the problem of determining whether
-         # there should be a leading space.
-         first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-         first_token = (
-             first_token.decode() if isinstance(first_token, bytes) else first_token
-         )
-         if first_token.startswith("▁"):
-             old_output_str = " " + old_output_str
-         if self.input_text is None:
-             # TODO(lmzheng): This can be wrong. Check with Liangsheng.
-             self.input_text = self.tokenizer.decode(self.input_ids)
-         new_input_string = (
-             self.input_text
-             + self.output_and_jump_forward_str
-             + old_output_str
-             + jump_forward_str
-         )
-         new_input_ids = self.tokenizer.encode(new_input_string)
-         if self.pixel_values is not None:
-             # NOTE: This is a hack because the old input_ids contains the image padding
-             jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-         else:
-             jump_forward_tokens_len = (
-                 len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-             )
+         return surr_ids, read_ids, len(all_ids)

-         # print("=" * 100)
-         # print(f"Catch jump forward:\n{jump_forward_str}")
-         # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-         # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+     def detokenize_incrementally(self, inplace: bool = True):
+         surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()

-         self.input_ids = new_input_ids
-         self.output_ids = []
-         self.sampling_params.max_new_tokens = max(
-             self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
+         surr_text = self.tokenizer.decode(
+             surr_ids,
+             skip_special_tokens=self.sampling_params.skip_special_tokens,
+             spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
          )
-         self.regex_fsm_state = next_state
-         self.output_and_jump_forward_str = (
-             self.output_and_jump_forward_str + old_output_str + jump_forward_str
+         new_text = self.tokenizer.decode(
+             read_ids,
+             skip_special_tokens=self.sampling_params.skip_special_tokens,
+             spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
          )

-         # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-         # print("*" * 100)
+         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+             new_text = new_text[len(surr_text) :]
+             if inplace:
+                 self.decoded_text += new_text
+                 self.surr_offset = self.read_offset
+                 self.read_offset = num_all_tokens
+
+             return True, new_text
+
+         return False, ""
+
+     def max_new_tokens(self):
+         return self.sampling_params.max_new_tokens

      def check_finished(self):
-         if self.finished:
+         if self.finished():
              return

          if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-             self.finished = True
-             self.finish_reason = FinishReason.LENGTH
+             self.finished_reason = FINISH_LENGTH(len(self.output_ids))
              return

          if (
              self.output_ids[-1] == self.tokenizer.eos_token_id
-             and self.sampling_params.ignore_eos == False
+             and not self.sampling_params.ignore_eos
          ):
-             self.finished = True
-             self.finish_reason = FinishReason.EOS_TOKEN
+             self.finished_reason = FINISH_MATCHED_TOKEN(
+                 matched=self.tokenizer.eos_token_id
+             )
              return

          if len(self.sampling_params.stop_strs) > 0:
@@ -155,14 +192,62 @@ class Req:
              )

          for stop_str in self.sampling_params.stop_strs:
-             if stop_str in tail_str:
-                 self.finished = True
-                 self.finish_reason = FinishReason.STOP_STR
-                 self.hit_stop_str = stop_str
+             if stop_str in tail_str or stop_str in self.decoded_text:
+                 self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                  return

+     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+         if self.origin_input_text is None:
+             # Recovering text can only use unpadded ids
+             self.origin_input_text = self.tokenizer.decode(
+                 self.origin_input_ids_unpadded
+             )
+
+         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
+         all_ids = self.tokenizer.encode(all_text)
+         prompt_tokens = len(self.origin_input_ids_unpadded)
+
+         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+             # TODO(lsyin): fix token fusion
+             warnings.warn(
+                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+             )
+             return False
+
+         old_output_ids = self.output_ids
+         self.output_ids = all_ids[prompt_tokens:]
+         self.decoded_text = self.decoded_text + jump_forward_str
+         self.surr_offset = prompt_tokens
+         self.read_offset = len(all_ids)
+
+         # NOTE: A trick to reduce the surrouding tokens decoding overhead
+         for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+             surr_text_ = self.tokenizer.decode(
+                 all_ids[self.read_offset - i : self.read_offset]
+             )
+             if not surr_text_.endswith("�"):
+                 self.surr_offset = self.read_offset - i
+                 break
+
+         self.regex_fsm_state = next_state
+
+         if self.return_logprob:
+             # For fast-forward part's logprobs
+             k = 0
+             for i, old_id in enumerate(old_output_ids):
+                 if old_id == self.output_ids[i]:
+                     k = k + 1
+                 else:
+                     break
+             self.decode_token_logprobs = self.decode_token_logprobs[:k]
+             self.decode_top_logprobs = self.decode_top_logprobs[:k]
+             self.logprob_start_len = prompt_tokens + k
+             self.last_update_decode_tokens = len(self.output_ids) - k
+
+         return True
+
      def __repr__(self):
-         return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
+         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "


  @dataclass
@@ -218,6 +303,10 @@ class Batch:
      def is_empty(self):
          return len(self.reqs) == 0

+     # whether batch has at least 1 streaming request
+     def has_stream(self) -> bool:
+         return any(r.stream for r in self.reqs)
+
      def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
          device = "cuda"
          bs = len(self.reqs)
@@ -333,8 +422,12 @@

      def retract_decode(self):
          sorted_indices = [i for i in range(len(self.reqs))]
+         # TODO(lsyin): improve the priority of retraction
          sorted_indices.sort(
-             key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+             key=lambda i: (
+                 len(self.reqs[i].output_ids),
+                 -len(self.reqs[i].origin_input_ids),
+             ),
              reverse=True,
          )

@@ -353,18 +446,22 @@
              ][last_uncached_pos : seq_lens_cpu[idx]]
              self.token_to_kv_pool.dec_refs(token_indices)

+             # release the last node
              self.tree_cache.dec_lock_ref(req.last_node)
+
              req.prefix_indices = None
              req.last_node = None
              req.extend_input_len = 0
-             req.output_ids = []
-             req.regex_fsm_state = 0
+
+             # For incremental logprobs
+             req.last_update_decode_tokens = 0
+             req.logprob_start_len = 10**9

          self.filter_batch(sorted_indices)

          return retracted_reqs

-     def check_for_jump_forward(self):
+     def check_for_jump_forward(self, model_runner):
          jump_forward_reqs = []
          filter_indices = [i for i in range(len(self.reqs))]

@@ -372,18 +469,54 @@

          for i, req in enumerate(self.reqs):
              if req.jump_forward_map is not None:
-                 res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                 if res is not None:
-                     jump_forward_str, next_state = res
-                     if len(jump_forward_str) <= 1:
+                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                     req.regex_fsm_state
+                 )
+                 if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                     suffix_bytes = []
+                     continuation_range = range(0x80, 0xC0)
+                     cur_state = req.regex_fsm_state
+                     while (
+                         len(jump_forward_bytes)
+                         and jump_forward_bytes[0][0] in continuation_range
+                     ):
+                         # continuation bytes
+                         byte_edge = jump_forward_bytes.pop(0)
+                         suffix_bytes.append(byte_edge[0])
+                         cur_state = byte_edge[1]
+
+                     suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                     suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                     # Current ids, for cache and revert
+                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                     cur_output_ids = req.output_ids
+
+                     req.output_ids.extend(suffix_ids)
+                     decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                     if not decode_res:
+                         req.output_ids = cur_output_ids
                          continue

-                     if req_pool_indices_cpu is None:
-                         req_pool_indices_cpu = self.req_pool_indices.tolist()
+                     (
+                         jump_forward_str,
+                         next_state,
+                     ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+
+                     # Make the incrementally decoded text part of jump_forward_str
+                     # so that the UTF-8 will not corrupt
+                     jump_forward_str = new_text + jump_forward_str
+                     if not req.jump_forward_and_retokenize(
+                         jump_forward_str, next_state
+                     ):
+                         req.output_ids = cur_output_ids
+                         continue

                      # insert the old request into tree_cache
+                     if req_pool_indices_cpu is None:
+                         req_pool_indices_cpu = self.req_pool_indices.tolist()
                      self.tree_cache.cache_req(
-                         token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                         token_ids=cur_all_ids,
                          last_uncached_pos=len(req.prefix_indices),
                          req_pool_idx=req_pool_indices_cpu[i],
                      )
391
524
  # unlock the last node
392
525
  self.tree_cache.dec_lock_ref(req.last_node)
393
526
 
394
- # jump-forward
395
- req.jump_forward_and_retokenize(jump_forward_str, next_state)
527
+ # re-applying image padding
528
+ if req.pixel_values is not None:
529
+ (
530
+ req.origin_input_ids,
531
+ req.image_offset,
532
+ ) = model_runner.model.pad_input_ids(
533
+ req.origin_input_ids_unpadded,
534
+ req.pad_value,
535
+ req.pixel_values.shape,
536
+ req.image_size,
537
+ )
396
538
 
397
539
  jump_forward_reqs.append(req)
398
540
  filter_indices.remove(i)
@@ -515,7 +657,7 @@ class Batch:
515
657
  if req.regex_fsm is not None:
516
658
  allowed_mask.zero_()
517
659
  allowed_mask[
518
- req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
660
+ req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
519
661
  ] = 1
520
662
  logits[i].masked_fill_(~allowed_mask, float("-inf"))
521
663
 
@@ -534,7 +676,7 @@ class Batch:
534
676
  batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
535
677
  for i, req in enumerate(self.reqs):
536
678
  if req.regex_fsm is not None:
537
- req.regex_fsm_state = req.regex_fsm.next_state(
679
+ req.regex_fsm_state = req.regex_fsm.get_next_state(
538
680
  req.regex_fsm_state, batch_next_token_ids_cpu[i]
539
681
  )
540
682
 
sglang/srt/managers/controller/manager_multi.py (new file)
@@ -0,0 +1,191 @@
+ """
+ A controller that manages multiple data parallel workers.
+ Each data parallel worker can manage multiple tensor parallel workers.
+ """
+
+ import asyncio
+ import logging
+ from concurrent.futures import ThreadPoolExecutor
+ from enum import Enum, auto
+ from typing import Dict
+
+ import zmq
+ import zmq.asyncio
+
+ from sglang.global_config import global_config
+ from sglang.srt.managers.controller.dp_worker import (
+     DataParallelWorkerThread,
+     start_data_parallel_worker,
+ )
+ from sglang.srt.managers.io_struct import (
+     AbortReq,
+     FlushCacheReq,
+     TokenizedGenerateReqInput,
+ )
+ from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger("srt.controller")
+
+
+ class LoadBalanceMethod(Enum):
+     ROUND_ROBIN = auto()
+     SHORTEST_QUEUE = auto()
+
+     @classmethod
+     def from_str(cls, method: str):
+         method = method.upper()
+         try:
+             return cls[method]
+         except KeyError as exc:
+             raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+ class Controller:
+     def __init__(
+         self,
+         load_balance_method: str,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+         model_overide_args,
+     ):
+         self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+         self.server_args = server_args
+         self.port_args = port_args
+
+         if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
+             self.round_robin_counter = 0
+
+         self.dispatch_lookup = {
+             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+         }
+         self.dispatching = self.dispatch_lookup[self.load_balance_method]
+
+         # Init communication
+         context = zmq.asyncio.Context()
+         self.recv_from_tokenizer = context.socket(zmq.PULL)
+         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+         # Init status
+         self.recv_reqs = []
+
+         # Start data parallel workers
+         self.workers: Dict[int, DataParallelWorkerThread] = {}
+         tp_size = server_args.tp_size
+
+         def start_dp_worker(i):
+             try:
+                 gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
+                 worker_thread = start_data_parallel_worker(
+                     server_args, port_args, model_overide_args, gpu_ids, i
+                 )
+                 self.workers[i] = worker_thread
+             except Exception:
+                 logger.error(
+                     f"Failed to start local worker {i}\n{get_exception_traceback()}"
+                 )
+
+         for i in range(server_args.dp_size):
+             start_dp_worker(i)
+
+         # Parallel launch is slower, probably due to the disk bandwidth limitations.
+         # with ThreadPoolExecutor(server_args.dp_size) as executor:
+         #     executor.map(start_dp_worker, range(server_args.dp_size))
+
+     def have_any_live_worker(self):
+         return any(worker_thread.liveness for worker_thread in self.workers.values())
+
+     def put_req_to_worker(self, worker_id, req):
+         self.workers[worker_id].request_queue.put(req)
+
+     async def round_robin_scheduler(self, input_requests):
+         available_workers = list(self.workers.keys())
+         for r in input_requests:
+             self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+             self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                 available_workers
+             )
+         return
+
+     async def shortest_queue_scheduler(self, input_requests):
+         for r in input_requests:
+             worker = min(
+                 self.workers, key=lambda w: self.workers[w].request_queue.qsize()
+             )
+             self.put_req_to_worker(worker, r)
+         return
+
+     async def remove_dead_workers(self):
+         for i in list(self.workers.keys()):
+             worker_thread = self.workers[i]
+             if not worker_thread.liveness:
+                 worker_thread.join()
+                 # move unsuccessful requests back to the queue
+                 while not worker_thread.request_queue.empty():
+                     self.recv_reqs.append(worker_thread.request_queue.get())
+                 del self.workers[i]
+                 logger.info(f"Stale worker {i} removed")
+
+     async def loop_for_forward(self):
+         while True:
+             await self.remove_dead_workers()
+
+             if self.have_any_live_worker():
+                 next_step_input = list(self.recv_reqs)
+                 self.recv_reqs = []
+                 if next_step_input:
+                     await self.dispatching(next_step_input)
+             # else:
+             #     logger.error("There is no live worker.")
+
+             await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+     async def loop_for_recv_requests(self):
+         while True:
+             recv_req = await self.recv_from_tokenizer.recv_pyobj()
+             if isinstance(recv_req, FlushCacheReq):
+                 # TODO(lsyin): apply more specific flushCacheReq
+                 for worker_thread in self.workers.values():
+                     worker_thread.request_queue.put(recv_req)
+             elif isinstance(recv_req, TokenizedGenerateReqInput):
+                 self.recv_reqs.append(recv_req)
+             elif isinstance(recv_req, AbortReq):
+                 in_queue = False
+                 for i, req in enumerate(self.recv_reqs):
+                     if req.rid == recv_req.rid:
+                         self.recv_reqs[i] = recv_req
+                         in_queue = True
+                         break
+                 if not in_queue:
+                     # Send abort req to all TP groups
+                     for worker in list(self.workers.keys()):
+                         self.put_req_to_worker(worker, recv_req)
+             else:
+                 logger.error(f"Invalid object: {recv_req}")
+
+
+ def start_controller_process(
+     server_args: ServerArgs,
+     port_args: PortArgs,
+     pipe_writer,
+     model_overide_args=None,
+ ):
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+
+     try:
+         controller = Controller(
+             server_args.load_balance_method, server_args, port_args, model_overide_args
+         )
+     except Exception:
+         pipe_writer.send(get_exception_traceback())
+         raise
+
+     pipe_writer.send("init ok")
+     loop = asyncio.get_event_loop()
+     asyncio.set_event_loop(loop)
+     loop.create_task(controller.loop_for_recv_requests())
+     loop.run_until_complete(controller.loop_for_forward())
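
The two schedulers above reduce to a small core. A minimal sketch with plain queue.Queue objects standing in for the DataParallelWorkerThread request queues; the names and worker count are illustrative:

# Sketch of the round-robin and shortest-queue dispatch policies.
import queue

workers = {i: queue.Queue() for i in range(4)}
rr_counter = 0

def round_robin_dispatch(reqs):
    # Cycle a counter over the worker ids, one request at a time.
    global rr_counter
    ids = list(workers.keys())
    for r in reqs:
        workers[ids[rr_counter]].put(r)
        rr_counter = (rr_counter + 1) % len(ids)

def shortest_queue_dispatch(reqs):
    # Send each request to the worker with the fewest pending requests.
    for r in reqs:
        wid = min(workers, key=lambda w: workers[w].qsize())
        workers[wid].put(r)

round_robin_dispatch([f"req{i}" for i in range(6)])  # workers 0,1,2,3,0,1
shortest_queue_dispatch(["req6", "req7"])            # fills workers 2 and 3
print({w: q.qsize() for w, q in workers.items()})    # -> {0: 2, 1: 2, 2: 2, 3: 2}

Round-robin just cycles a counter over the worker ids; shortest-queue pays a qsize() check per request in exchange for better balance when request costs vary.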