sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/infer_batch.py
@@ -1,503 +0,0 @@
- from dataclasses import dataclass
- from enum import Enum, auto
- from typing import List
-
- import numpy as np
- import torch
- from sglang.srt.managers.router.radix_cache import RadixCache
- from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-
-
- class ForwardMode(Enum):
-     PREFILL = auto()
-     EXTEND = auto()
-     DECODE = auto()
-
-
- class FinishReason(Enum):
-     LENGTH = auto()
-     EOS_TOKEN = auto()
-     STOP_STR = auto()
-
-
- class Req:
-     def __init__(self, rid, input_text, input_ids):
-         self.rid = rid
-         self.input_text = input_text
-         self.input_ids = input_ids
-         self.output_ids = []
-
-         # Since jump forward may retokenize the prompt with partial outputs,
-         # we maintain the original prompt length to report the correct usage.
-         self.prompt_tokens = len(input_ids)
-         # The number of decoded tokens for token usage report. Note that
-         # this does not include the jump forward tokens.
-         self.completion_tokens_wo_jump_forward = 0
-
-         # For vision input
-         self.pixel_values = None
-         self.image_size = None
-         self.image_offset = 0
-         self.pad_value = None
-
-         self.sampling_params = None
-         self.return_logprob = False
-         self.logprob_start_len = 0
-         self.stream = False
-
-         self.tokenizer = None
-         self.finished = False
-         self.finish_reason = None
-         self.hit_stop_str = None
-
-         self.extend_input_len = 0
-         self.prefix_indices = []
-         self.last_node = None
-
-         self.logprob = None
-         self.token_logprob = None
-         self.normalized_logprob = None
-
-         # For constrained decoding
-         self.regex_fsm = None
-         self.regex_fsm_state = 0
-         self.jump_forward_map = None
-         self.output_and_jump_forward_str = ""
-
-     def max_new_tokens(self):
-         return self.sampling_params.max_new_tokens
-
-     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-         old_output_str = self.tokenizer.decode(self.output_ids)
-         # FIXME: This logic does not really solve the problem of determining whether
-         # there should be a leading space.
-         first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-         first_token = (
-             first_token.decode() if isinstance(first_token, bytes) else first_token
-         )
-         if first_token.startswith("▁"):
-             old_output_str = " " + old_output_str
-         new_input_string = (
-             self.input_text
-             + self.output_and_jump_forward_str
-             + old_output_str
-             + jump_forward_str
-         )
-         new_input_ids = self.tokenizer.encode(new_input_string)
-         if self.pixel_values is not None:
-             # NOTE: This is a hack because the old input_ids contains the image padding
-             jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-         else:
-             jump_forward_tokens_len = (
-                 len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-             )
-
-         # print("=" * 100)
-         # print(f"Catch jump forward:\n{jump_forward_str}")
-         # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-         # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-         self.input_ids = new_input_ids
-         self.output_ids = []
-         self.sampling_params.max_new_tokens = max(
-             self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-         )
-         self.regex_fsm_state = next_state
-         self.output_and_jump_forward_str = (
-             self.output_and_jump_forward_str + old_output_str + jump_forward_str
-         )
-
-         # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-         # print("*" * 100)
-
-     def check_finished(self):
-         if self.finished:
-             return
-
-         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-             self.finished = True
-             self.finish_reason = FinishReason.LENGTH
-             return
-
-         if (
-             self.output_ids[-1] == self.tokenizer.eos_token_id
-             and self.sampling_params.ignore_eos == False
-         ):
-             self.finished = True
-             self.finish_reason = FinishReason.EOS_TOKEN
-             return
-
-         if len(self.sampling_params.stop_strs) > 0:
-             tail_str = self.tokenizer.decode(
-                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
-             )
-
-             for stop_str in self.sampling_params.stop_strs:
-                 if stop_str in tail_str:
-                     self.finished = True
-                     self.finish_reason = FinishReason.STOP_STR
-                     self.hit_stop_str = stop_str
-                     return
-
-     def __repr__(self):
-         return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
-
-
- @dataclass
- class Batch:
-     reqs: List[Req]
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: TokenToKVPool
-     tree_cache: RadixCache
-
-     # batched arguments to model runner
-     input_ids: torch.Tensor = None
-     req_pool_indices: torch.Tensor = None
-     seq_lens: torch.Tensor = None
-     prefix_lens: torch.Tensor = None
-     position_ids_offsets: torch.Tensor = None
-     out_cache_loc: torch.Tensor = None
-     out_cache_cont_start: torch.Tensor = None
-     out_cache_cont_end: torch.Tensor = None
-     return_logprob: bool = False
-
-     # for multimodal
-     pixel_values: List[torch.Tensor] = None
-     image_sizes: List[List[int]] = None
-     image_offsets: List[int] = None
-
-     # other arguments for control
-     output_ids: torch.Tensor = None
-     extend_num_tokens: int = None
-
-     # batched sampling params
-     temperatures: torch.Tensor = None
-     top_ps: torch.Tensor = None
-     top_ks: torch.Tensor = None
-     frequency_penalties: torch.Tensor = None
-     presence_penalties: torch.Tensor = None
-     logit_bias: torch.Tensor = None
-
-     @classmethod
-     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-         return_logprob = any(req.return_logprob for req in reqs)
-
-         return cls(
-             reqs=reqs,
-             req_to_token_pool=req_to_token_pool,
-             token_to_kv_pool=token_to_kv_pool,
-             tree_cache=tree_cache,
-             return_logprob=return_logprob,
-         )
-
-     def is_empty(self):
-         return len(self.reqs) == 0
-
-     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-         device = "cuda"
-         bs = len(self.reqs)
-         reqs = self.reqs
-         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-         prefix_indices = [r.prefix_indices for r in reqs]
-
-         # Handle prefix
-         flatten_input_ids = []
-         extend_lens = []
-         prefix_lens = []
-         seq_lens = []
-
-         req_pool_indices = self.req_to_token_pool.alloc(bs)
-         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-         for i in range(bs):
-             flatten_input_ids.extend(input_ids[i])
-             extend_lens.append(len(input_ids[i]))
-
-             if len(prefix_indices[i]) == 0:
-                 prefix_lens.append(0)
-             else:
-                 prefix_lens.append(len(prefix_indices[i]))
-                 self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                     : len(prefix_indices[i])
-                 ] = prefix_indices[i]
-
-             seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-         # Alloc mem
-         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-         if out_cache_loc is None:
-             if not self.tree_cache.disable:
-                 self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                 out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-             if out_cache_loc is None:
-                 print("Prefill out of memory. This should nerver happen.")
-                 self.tree_cache.pretty_print()
-                 exit()
-
-         pt = 0
-         for i in range(bs):
-             self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-             ] = out_cache_loc[pt : pt + extend_lens[i]]
-             pt += extend_lens[i]
-
-         # Handle logit bias
-         logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
-         for i in range(bs):
-             if reqs[i].sampling_params.dtype == "int":
-                 logit_bias[i] = int_token_logit_bias
-
-         # Set fields
-         self.input_ids = torch.tensor(
-             flatten_input_ids, dtype=torch.int32, device=device
-         )
-         self.pixel_values = [r.pixel_values for r in reqs]
-         self.image_sizes = [r.image_size for r in reqs]
-         self.image_offsets = [
-             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-         ]
-         self.req_pool_indices = req_pool_indices
-         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-         self.position_ids_offsets = position_ids_offsets
-         self.extend_num_tokens = extend_num_tokens
-         self.out_cache_loc = out_cache_loc
-
-         self.temperatures = torch.tensor(
-             [r.sampling_params.temperature for r in reqs],
-             dtype=torch.float,
-             device=device,
-         ).view(-1, 1)
-         self.top_ps = torch.tensor(
-             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-         ).view(-1, 1)
-         self.top_ks = torch.tensor(
-             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-         ).view(-1, 1)
-         self.frequency_penalties = torch.tensor(
-             [r.sampling_params.frequency_penalty for r in reqs],
-             dtype=torch.float,
-             device=device,
-         )
-         self.presence_penalties = torch.tensor(
-             [r.sampling_params.presence_penalty for r in reqs],
-             dtype=torch.float,
-             device=device,
-         )
-         self.logit_bias = logit_bias
-
-     def check_decode_mem(self):
-         bs = len(self.reqs)
-         if self.token_to_kv_pool.available_size() >= bs:
-             return True
-
-         if not self.tree_cache.disable:
-             self.tree_cache.evict(bs, self.token_to_kv_pool.free)
-             if self.token_to_kv_pool.available_size() >= bs:
-                 return True
-
-         return False
-
-     def retract_decode(self):
-         sorted_indices = [i for i in range(len(self.reqs))]
-         sorted_indices.sort(
-             key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
-             reverse=True,
-         )
-
-         retracted_reqs = []
-         seq_lens_np = self.seq_lens.cpu().numpy()
-         req_pool_indices_np = self.req_pool_indices.cpu().numpy()
-         while self.token_to_kv_pool.available_size() < len(self.reqs):
-             idx = sorted_indices.pop()
-             req = self.reqs[idx]
-             retracted_reqs.append(req)
-
-             self.tree_cache.dec_ref_counter(req.last_node)
-             req.prefix_indices = None
-             req.last_node = None
-             req.extend_input_len = 0
-             req.output_ids = []
-             req.regex_fsm_state = 0
-
-             # TODO: apply more fine-grained retraction
-
-             token_indices = self.req_to_token_pool.req_to_token[
-                 req_pool_indices_np[idx]
-             ][: seq_lens_np[idx]]
-             self.token_to_kv_pool.free(token_indices)
-
-         self.filter_batch(sorted_indices)
-
-         return retracted_reqs
-
-     def check_for_jump_forward(self):
-         jump_forward_reqs = []
-         filter_indices = [i for i in range(len(self.reqs))]
-
-         req_pool_indices_cpu = None
-
-         for i, req in enumerate(self.reqs):
-             if req.jump_forward_map is not None:
-                 res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                 if res is not None:
-                     jump_forward_str, next_state = res
-                     if len(jump_forward_str) <= 1:
-                         continue
-
-                     # insert the old request into tree_cache
-                     token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
-                     if req_pool_indices_cpu is None:
-                         req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
-                     req_pool_idx = req_pool_indices_cpu[i]
-                     indices = self.req_to_token_pool.req_to_token[
-                         req_pool_idx, : len(token_ids_in_memory)
-                     ]
-                     prefix_len = self.tree_cache.insert(
-                         token_ids_in_memory, indices.clone()
-                     )
-                     self.token_to_kv_pool.free(indices[:prefix_len])
-                     self.req_to_token_pool.free(req_pool_idx)
-                     self.tree_cache.dec_ref_counter(req.last_node)
-
-                     # jump-forward
-                     req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
-                     jump_forward_reqs.append(req)
-                     filter_indices.remove(i)
-
-         if len(filter_indices) < len(self.reqs):
-             self.filter_batch(filter_indices)
-
-         return jump_forward_reqs
-
-     def prepare_for_decode(self, input_ids=None):
-         if input_ids is None:
-             input_ids = [
-                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
-             ]
-         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
-         self.seq_lens.add_(1)
-         self.prefix_lens = None
-
-         # Alloc mem
-         bs = len(self.reqs)
-         alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-         if alloc_res is None:
-             self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-             if self.out_cache_loc is None:
-                 print("Decode out of memory. This should nerver happen.")
-                 self.tree_cache.pretty_print()
-                 exit()
-
-             self.out_cache_cont_start = None
-             self.out_cache_cont_end = None
-         else:
-             self.out_cache_loc = alloc_res[0]
-             self.out_cache_cont_start = alloc_res[1]
-             self.out_cache_cont_end = alloc_res[2]
-
-         self.req_to_token_pool.req_to_token[
-             self.req_pool_indices, self.seq_lens - 1
-         ] = self.out_cache_loc
-
-     def filter_batch(self, unfinished_indices: List[int]):
-         self.reqs = [self.reqs[i] for i in unfinished_indices]
-         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-         self.seq_lens = self.seq_lens[new_indices]
-         self.input_ids = None
-         self.req_pool_indices = self.req_pool_indices[new_indices]
-         self.prefix_lens = None
-         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-         for item in [
-             "temperatures",
-             "top_ps",
-             "top_ks",
-             "frequency_penalties",
-             "presence_penalties",
-             "logit_bias",
-         ]:
-             setattr(self, item, getattr(self, item)[new_indices])
-
-     def merge(self, other):
-         self.reqs.extend(other.reqs)
-
-         self.req_pool_indices = torch.concat(
-             [self.req_pool_indices, other.req_pool_indices]
-         )
-         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-         self.prefix_lens = None
-         self.position_ids_offsets = torch.concat(
-             [self.position_ids_offsets, other.position_ids_offsets]
-         )
-         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-         for item in [
-             "temperatures",
-             "top_ps",
-             "top_ks",
-             "frequency_penalties",
-             "presence_penalties",
-             "logit_bias",
-         ]:
-             setattr(
-                 self, item, torch.concat([getattr(self, item), getattr(other, item)])
-             )
-
-     def sample(self, logits: torch.Tensor):
-         # Post process logits
-         logits = logits.contiguous()
-         logits.div_(self.temperatures)
-         logits.add_(self.logit_bias)
-
-         has_regex = any(req.regex_fsm is not None for req in self.reqs)
-         if has_regex:
-             allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-             for i, req in enumerate(self.reqs):
-                 if req.regex_fsm is not None:
-                     allowed_mask.zero_()
-                     allowed_mask[
-                         req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
-                     ] = 1
-                     logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-         # TODO(lmzheng): apply penalty
-         probs = torch.softmax(logits, dim=-1)
-         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-         sampled_index = torch.multinomial(probs_sort, num_samples=1)
-         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-             -1
-         )
-         batch_next_token_probs = torch.gather(
-             probs_sort, dim=1, index=sampled_index
-         ).view(-1)
-
-         if has_regex:
-             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-             for i, req in enumerate(self.reqs):
-                 if req.regex_fsm is not None:
-                     req.regex_fsm_state = req.regex_fsm.next_state(
-                         req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                     )
-
-         return batch_next_token_ids, batch_next_token_probs
-
-
- def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-     probs_sum = torch.cumsum(probs_sort, dim=-1)
-     probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-     probs_sort[
-         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-     ] = 0.0
-     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-     return probs_sort, probs_idx
sglang/srt/managers/router/manager.py
@@ -1,79 +0,0 @@
- import asyncio
- import logging
-
- import uvloop
- import zmq
- import zmq.asyncio
- from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
- from sglang.srt.managers.router.model_rpc import ModelRpcClient
- from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import get_exception_traceback
-
- asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-
- class RouterManager:
-     def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
-         # Init communication
-         context = zmq.asyncio.Context(2)
-         self.recv_from_tokenizer = context.socket(zmq.PULL)
-         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-         self.send_to_detokenizer = context.socket(zmq.PUSH)
-         self.send_to_detokenizer.connect(
-             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-         )
-
-         # Init status
-         self.model_client = model_client
-         self.recv_reqs = []
-
-         # Init some configs
-         self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
-
-     async def loop_for_forward(self):
-         while True:
-             next_step_input = list(self.recv_reqs)
-             self.recv_reqs = []
-             out_pyobjs = await self.model_client.step(next_step_input)
-
-             for obj in out_pyobjs:
-                 self.send_to_detokenizer.send_pyobj(obj)
-
-             # async sleep for receiving the subsequent request and avoiding cache miss
-             if len(out_pyobjs) != 0:
-                 has_finished = any([obj.finished for obj in out_pyobjs])
-                 if has_finished:
-                     await asyncio.sleep(self.extend_dependency_time)
-
-             await asyncio.sleep(0.0006)
-
-     async def loop_for_recv_requests(self):
-         while True:
-             recv_req = await self.recv_from_tokenizer.recv_pyobj()
-             self.recv_reqs.append(recv_req)
-
-
- def start_router_process(
-     server_args: ServerArgs,
-     port_args: PortArgs,
-     pipe_writer,
- ):
-     logging.basicConfig(
-         level=getattr(logging, server_args.log_level.upper()),
-         format="%(message)s",
-     )
-
-     try:
-         model_client = ModelRpcClient(server_args, port_args)
-         router = RouterManager(model_client, port_args)
-     except Exception:
-         pipe_writer.send(get_exception_traceback())
-         raise
-
-     pipe_writer.send("init ok")
-
-     loop = asyncio.new_event_loop()
-     asyncio.set_event_loop(loop)
-     loop.create_task(router.loop_for_recv_requests())
-     loop.run_until_complete(router.loop_for_forward())