sglang-0.1.17-py3-none-any.whl → sglang-0.1.19-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/infer_batch.py
@@ -1,596 +0,0 @@
-from dataclasses import dataclass
-from enum import IntEnum, auto
-from typing import List
-
-import numpy as np
-import torch
-
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-
-
-class ForwardMode(IntEnum):
-    PREFILL = auto()
-    EXTEND = auto()
-    DECODE = auto()
-
-
-class FinishReason(IntEnum):
-    EOS_TOKEN = auto()
-    LENGTH = auto()
-    STOP_STR = auto()
-    ABORT = auto()
-
-    @staticmethod
-    def to_str(reason):
-        if reason == FinishReason.EOS_TOKEN:
-            return None
-        elif reason == FinishReason.LENGTH:
-            return "length"
-        elif reason == FinishReason.STOP_STR:
-            return "stop"
-        elif reason == FinishReason.ABORT:
-            return "abort"
-        else:
-            return None
-
-
-class Req:
-    def __init__(self, rid, origin_input_text, origin_input_ids):
-        self.rid = rid
-        self.origin_input_text = origin_input_text
-        self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
-
-        # For vision input
-        self.pixel_values = None
-        self.image_size = None
-        self.image_offset = 0
-        self.pad_value = None
-
-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
-        # Check finish
-        self.tokenizer = None
-        self.finished = False
-        self.finish_reason = None
-        self.hit_stop_str = None
-
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
-        # Logprobs
-        self.return_logprob = False
-        self.logprob_start_len = 0
-        self.top_logprobs_num = 0
-        self.normalized_prompt_logprob = None
-        self.prefill_token_logprobs = None
-        self.prefill_top_logprobs = None
-        self.decode_token_logprobs = []
-        self.decode_top_logprobs = []
-        # The tokens is prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0
-
-        # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
-
-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
-
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
-    def check_finished(self):
-        if self.finished:
-            return
-
-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished = True
-            self.finish_reason = FinishReason.LENGTH
-            return
-
-        if (
-            self.output_ids[-1] == self.tokenizer.eos_token_id
-            and self.sampling_params.ignore_eos == False
-        ):
-            self.finished = True
-            self.finish_reason = FinishReason.EOS_TOKEN
-            return
-
-        if len(self.sampling_params.stop_strs) > 0:
-            tail_str = self.tokenizer.decode(
-                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
-            )
-
-            for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
-                    self.finished = True
-                    self.finish_reason = FinishReason.STOP_STR
-                    self.hit_stop_str = stop_str
-                    return
-
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
-        all_ids = self.tokenizer.encode(all_text)
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
-
-        self.regex_fsm_state = next_state
-
-        if self.return_logprob:
-            # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-            self.decode_token_logprobs = self.decode_token_logprobs[:k]
-            self.decode_top_logprobs = self.decode_top_logprobs[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
-
-    def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
-
-
-@dataclass
-class Batch:
-    reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-    tree_cache: RadixCache
-
-    # batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    # for processing logprobs
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
-    # other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    frequency_penalties: torch.Tensor = None
-    presence_penalties: torch.Tensor = None
-    logit_bias: torch.Tensor = None
-
-    @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-
-        return cls(
-            reqs=reqs,
-            req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
-            tree_cache=tree_cache,
-            return_logprob=return_logprob,
-        )
-
-    def is_empty(self):
-        return len(self.reqs) == 0
-
-    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
-        bs = len(self.reqs)
-        reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        flatten_input_ids = []
-        extend_lens = []
-        prefix_lens = []
-        seq_lens = []
-
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
-            extend_lens.append(len(input_ids[i]))
-
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
-
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-        # Alloc mem
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-        if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                print("Prefill out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-        pt = 0
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
-
-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i] = int_token_logit_bias
-
-        # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-        ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
-        self.extend_num_tokens = extend_num_tokens
-        self.out_cache_loc = out_cache_loc
-        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
-
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
-
-    def check_decode_mem(self):
-        bs = len(self.reqs)
-        if self.token_to_kv_pool.available_size() >= bs:
-            return True
-
-        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
-
-        if self.token_to_kv_pool.available_size() >= bs:
-            return True
-
-        return False
-
-    def retract_decode(self):
-        sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
-        sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
-            reverse=True,
-        )
-
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
-            idx = sorted_indices.pop()
-            req = self.reqs[idx]
-            retracted_reqs.append(req)
-
-            # TODO: apply more fine-grained retraction
-            last_uncached_pos = len(req.prefix_indices)
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)

-            # release the last node
-            self.tree_cache.dec_lock_ref(req.last_node)
-
-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
-            req.prefix_indices = None
-            req.last_node = None
-            req.extend_input_len = 0
-            req.output_ids = []
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
-
-        self.filter_batch(sorted_indices)
-
-        return retracted_reqs
-
-    def check_for_jump_forward(self, model_runner):
-        jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
-
-        req_pool_indices_cpu = None
-
-        for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
-                        continue
-
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
-
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
-                        last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req_pool_indices_cpu[i],
-                    )
-
-                    # unlock the last node
-                    self.tree_cache.dec_lock_ref(req.last_node)
-
-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
-                    # re-applying image padding
-                    if req.pixel_values is not None:
-                        (
-                            req.origin_input_ids,
-                            req.image_offset,
-                        ) = model_runner.model.pad_input_ids(
-                            req.origin_input_ids_unpadded,
-                            req.pad_value,
-                            req.pixel_values.shape,
-                            req.image_size,
-                        )
-
-                    jump_forward_reqs.append(req)
-                    filter_indices.remove(i)
-
-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
-
-        return jump_forward_reqs
-
-    def prepare_for_decode(self, input_ids=None):
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
-            ]
-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
-        self.seq_lens.add_(1)
-        self.prefix_lens = None
-
-        # Alloc mem
-        bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
-
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
-
-    def filter_batch(self, unfinished_indices: List[int]):
-        self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
-        self.input_ids = None
-        self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
-                setattr(self, item, self_val[new_indices])
-
-    def merge(self, other: "Batch"):
-        self.reqs.extend(other.reqs)
-
-        self.req_pool_indices = torch.concat(
-            [self.req_pool_indices, other.req_pool_indices]
-        )
-        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.top_logprobs_nums.extend(other.top_logprobs_nums)
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "frequency_penalties",
-            "presence_penalties",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
-    def sample(self, logits: torch.Tensor):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        # TODO(lmzheng): apply penalty
-        probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
-
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
-
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
sglang/srt/managers/router/manager.py
@@ -1,82 +0,0 @@
-import asyncio
-import logging
-
-import uvloop
-import zmq
-import zmq.asyncio
-
-from sglang.global_config import global_config
-from sglang.srt.managers.router.model_rpc import ModelRpcClient
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback
-
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-
-class RouterManager:
-    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
-        # Init communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        self.send_to_detokenizer = context.socket(zmq.PUSH)
-        self.send_to_detokenizer.connect(
-            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-        )
-
-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
-
-        # Init some configs
-        self.request_dependency_time = global_config.request_dependency_time
-
-    async def loop_for_forward(self):
-        while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
-
-            for obj in out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-
-            # async sleep for receiving the subsequent request and avoiding cache miss
-            slept = False
-            if len(out_pyobjs) != 0:
-                has_finished = any([obj.finished for obj in out_pyobjs])
-                if has_finished:
-                    if self.request_dependency_time > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_time)
-
-            if not slept:
-                await asyncio.sleep(0.0006)
-
-    async def loop_for_recv_requests(self):
-        while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
-
-
-def start_router_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
-):
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
-
-    try:
-        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
-        router = RouterManager(model_client, port_args)
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-
-    pipe_writer.send("init ok")
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    loop.create_task(router.loop_for_recv_requests())
-    loop.run_until_complete(router.loop_for_forward())