sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/infer_batch.py
@@ -0,0 +1,908 @@
+ """Meta data for requests and batches"""
+
+ import warnings
+ from dataclasses import dataclass
+ from enum import IntEnum, auto
+ from typing import List, Union
+
+ import numpy as np
+ import torch
+
+ from sglang.srt.constrained import RegexGuide
+ from sglang.srt.constrained.jump_forward import JumpForwardMap
+ from sglang.srt.managers.controller.radix_cache import RadixCache
+ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+
+ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+ # Store some global server args
+ global_server_args_dict = {}
+
+
+ class ForwardMode(IntEnum):
+     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+     PREFILL = auto()
+     # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+     EXTEND = auto()
+     # Decode one token.
+     DECODE = auto()
+
+
+ class BaseFinishReason:
+     def __init__(self, is_error: bool = False):
+         self.is_error = is_error
+
+     def __str__(self):
+         raise NotImplementedError("Subclasses must implement this method")
+
+
+ class FINISH_MATCHED_TOKEN(BaseFinishReason):
+     def __init__(self, matched: Union[int, List[int]]):
+         super().__init__()
+         self.matched = matched
+
+     def __str__(self) -> str:
+         return f"FINISH_MATCHED_TOKEN: {self.matched}"
+
+
+ class FINISH_LENGTH(BaseFinishReason):
+     def __init__(self, length: int):
+         super().__init__()
+         self.length = length
+
+     def __str__(self) -> str:
+         return f"FINISH_LENGTH: {self.length}"
+
+
+ class FINISH_MATCHED_STR(BaseFinishReason):
+     def __init__(self, matched: str):
+         super().__init__()
+         self.matched = matched
+
+     def __str__(self) -> str:
+         return f"FINISH_MATCHED_STR: {self.matched}"
+
+
+ class FINISH_ABORT(BaseFinishReason):
+     def __init__(self):
+         super().__init__(is_error=True)
+
+     def __str__(self) -> str:
+         return "FINISH_ABORT"
+
+
+ class Req:
+     """Store all information of a request."""
+
+     def __init__(self, rid, origin_input_text, origin_input_ids):
+         # Input and output info
+         self.rid = rid
+         self.origin_input_text = origin_input_text
+         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
+         self.origin_input_ids = origin_input_ids
+         self.output_ids = []  # Each decode stage's output ids
+         self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+         # For incremental decoding
+         self.decoded_text = ""
+         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+         self.read_offset = None
+
+         # The number of decoded tokens for token usage report. Note that
+         # this does not include the jump forward tokens.
+         self.completion_tokens_wo_jump_forward = 0
+
+         # For vision input
+         self.pixel_values = None
+         self.image_size = None
+         self.image_offset = 0
+         self.pad_value = None
+
+         # Prefix info
+         self.extend_input_len = 0
+         self.prefix_indices = []
+         self.last_node = None
+
+         # Sampling parameters
+         self.sampling_params = None
+         self.stream = False
+
+         # Check finish
+         self.tokenizer = None
+         self.finished_reason = None
+
+         # Logprobs
+         self.return_logprob = False
+         self.logprob_start_len = 0
+         self.top_logprobs_num = 0
+         self.normalized_prompt_logprob = None
+         self.prefill_token_logprobs = None
+         self.prefill_top_logprobs = None
+         self.decode_token_logprobs = []
+         self.decode_top_logprobs = []
+         # These tokens are prefilled but need to be considered as decode tokens,
+         # so their decode logprobs still have to be updated
+         self.last_update_decode_tokens = 0
+
+         # Constrained decoding
+         self.regex_fsm: RegexGuide = None
+         self.regex_fsm_state: int = 0
+         self.jump_forward_map: JumpForwardMap = None
+
+     # Whether the request has reached a finished condition
+     def finished(self) -> bool:
+         return self.finished_reason is not None
+
+     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+     def init_detokenize_incrementally(self):
+         first_iter = self.surr_offset is None or self.read_offset is None
+
+         if first_iter:
+             self.read_offset = len(self.origin_input_ids_unpadded)
+             self.surr_offset = max(
+                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+             )
+
+         all_ids = self.origin_input_ids_unpadded + self.output_ids
+         surr_ids = all_ids[self.surr_offset : self.read_offset]
+         read_ids = all_ids[self.surr_offset :]
+
+         return surr_ids, read_ids, len(all_ids)
+
+     def detokenize_incrementally(self, inplace: bool = True):
+         surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+         surr_text = self.tokenizer.decode(
+             surr_ids,
+             skip_special_tokens=self.sampling_params.skip_special_tokens,
+             spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+         )
+         new_text = self.tokenizer.decode(
+             read_ids,
+             skip_special_tokens=self.sampling_params.skip_special_tokens,
+             spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+         )
+
+         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+             new_text = new_text[len(surr_text) :]
+             if inplace:
+                 self.decoded_text += new_text
+                 self.surr_offset = self.read_offset
+                 self.read_offset = num_all_tokens
+
+             return True, new_text
+
+         return False, ""
+
+     def check_finished(self):
+         if self.finished():
+             return
+
+         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+             self.finished_reason = FINISH_LENGTH(len(self.output_ids))
+             return
+
+         if (
+             self.output_ids[-1] == self.tokenizer.eos_token_id
+             and not self.sampling_params.ignore_eos
+         ):
+             self.finished_reason = FINISH_MATCHED_TOKEN(
+                 matched=self.tokenizer.eos_token_id
+             )
+             return
+
+         if len(self.sampling_params.stop_strs) > 0:
+             tail_str = self.tokenizer.decode(
+                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
+             )
+
+             for stop_str in self.sampling_params.stop_strs:
+                 if stop_str in tail_str or stop_str in self.decoded_text:
+                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
+                     return
+
+     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+         if self.origin_input_text is None:
+             # Recovering text can only use unpadded ids
+             self.origin_input_text = self.tokenizer.decode(
+                 self.origin_input_ids_unpadded
+             )
+
+         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
+         all_ids = self.tokenizer.encode(all_text)
+         prompt_tokens = len(self.origin_input_ids_unpadded)
+
+         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+             # TODO(lsyin): fix token fusion
+             warnings.warn(
+                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+             )
+             return False
+
+         old_output_ids = self.output_ids
+         self.output_ids = all_ids[prompt_tokens:]
+         self.decoded_text = self.decoded_text + jump_forward_str
+         self.surr_offset = prompt_tokens
+         self.read_offset = len(all_ids)
+
+         # NOTE: A trick to reduce the surrounding tokens decoding overhead
+         for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+             surr_text_ = self.tokenizer.decode(
+                 all_ids[self.read_offset - i : self.read_offset]
+             )
+             if not surr_text_.endswith("�"):
+                 self.surr_offset = self.read_offset - i
+                 break
+
+         self.regex_fsm_state = next_state
+
+         if self.return_logprob:
+             # For fast-forward part's logprobs
+             k = 0
+             for i, old_id in enumerate(old_output_ids):
+                 if old_id == self.output_ids[i]:
+                     k = k + 1
+                 else:
+                     break
+             self.decode_token_logprobs = self.decode_token_logprobs[:k]
+             self.decode_top_logprobs = self.decode_top_logprobs[:k]
+             self.logprob_start_len = prompt_tokens + k
+             self.last_update_decode_tokens = len(self.output_ids) - k
+
+         return True
+
+     def __repr__(self):
+         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+
+
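
The incremental detokenization above only emits text once the decoded window no longer ends in a partial UTF-8 sequence. A minimal self-contained sketch of the same windowed scheme, for reference only (decode stands in for any callable mapping token ids to text, e.g. a Hugging Face tokenizer's decode; this is an illustration, not code from the package):

    from typing import Callable, List, Tuple

    def stream_decode_step(
        decode: Callable[[List[int]], str],
        all_ids: List[int],
        surr_offset: int,
        read_offset: int,
    ) -> Tuple[str, int, int]:
        # Decode a small "surrounding" window and the window plus the new tokens;
        # emit only the suffix, and only when it does not end in U+FFFD
        # (the replacement character that marks an incomplete UTF-8 sequence).
        surr_text = decode(all_ids[surr_offset:read_offset])
        new_text = decode(all_ids[surr_offset:])
        if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
            return new_text[len(surr_text):], read_offset, len(all_ids)
        return "", surr_offset, read_offset
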
+ @dataclass
+ class Batch:
+     """Store all information of a batch."""
+
+     # Request, memory pool, and cache
+     reqs: List[Req]
+     req_to_token_pool: ReqToTokenPool
+     token_to_kv_pool: TokenToKVPool
+     tree_cache: RadixCache
+
+     # Batched arguments to model runner
+     input_ids: torch.Tensor = None
+     req_pool_indices: torch.Tensor = None
+     seq_lens: torch.Tensor = None
+     prefix_lens: torch.Tensor = None
+     position_ids_offsets: torch.Tensor = None
+     out_cache_loc: torch.Tensor = None
+
+     # For processing logprobs
+     return_logprob: bool = False
+     top_logprobs_nums: List[int] = None
+
+     # For multimodal
+     pixel_values: List[torch.Tensor] = None
+     image_sizes: List[List[int]] = None
+     image_offsets: List[int] = None
+
+     # Other arguments for control
+     output_ids: torch.Tensor = None
+     extend_num_tokens: int = None
+
+     # Batched sampling params
+     temperatures: torch.Tensor = None
+     top_ps: torch.Tensor = None
+     top_ks: torch.Tensor = None
+     frequency_penalties: torch.Tensor = None
+     presence_penalties: torch.Tensor = None
+     logit_bias: torch.Tensor = None
+
+     @classmethod
+     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+         return_logprob = any(req.return_logprob for req in reqs)
+
+         return cls(
+             reqs=reqs,
+             req_to_token_pool=req_to_token_pool,
+             token_to_kv_pool=token_to_kv_pool,
+             tree_cache=tree_cache,
+             return_logprob=return_logprob,
+         )
+
+     def is_empty(self):
+         return len(self.reqs) == 0
+
+     def has_stream(self) -> bool:
+         # Return whether batch has at least 1 streaming request
+         return any(r.stream for r in self.reqs)
+
+     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+         device = "cuda"
+         bs = len(self.reqs)
+         reqs = self.reqs
+         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
+         prefix_indices = [r.prefix_indices for r in reqs]
+
+         # Handle prefix
+         flatten_input_ids = []
+         extend_lens = []
+         prefix_lens = []
+         seq_lens = []
+
+         req_pool_indices = self.req_to_token_pool.alloc(bs)
+         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+         for i in range(bs):
+             flatten_input_ids.extend(input_ids[i])
+             extend_lens.append(len(input_ids[i]))
+
+             if len(prefix_indices[i]) == 0:
+                 prefix_lens.append(0)
+             else:
+                 prefix_lens.append(len(prefix_indices[i]))
+                 self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+                     : len(prefix_indices[i])
+                 ] = prefix_indices[i]
+
+             seq_lens.append(prefix_lens[-1] + extend_lens[-1])
+
+         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
+
+         # Allocate memory
+         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
+         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
+         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+         if out_cache_loc is None:
+             self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
+             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+
+             if out_cache_loc is None:
+                 print("Prefill out of memory. This should never happen.")
+                 self.tree_cache.pretty_print()
+                 exit()
+
+         pt = 0
+         for i in range(bs):
+             self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
+             ] = out_cache_loc[pt : pt + extend_lens[i]]
+             pt += extend_lens[i]
+
+         # Handle logit bias but only allocate when needed
+         logit_bias = None
+         for i in range(bs):
+             if reqs[i].sampling_params.dtype == "int":
+                 if logit_bias is None:
+                     logit_bias = torch.zeros(
+                         (bs, vocab_size), dtype=torch.float32, device=device
+                     )
+                 logit_bias[i] = int_token_logit_bias
+
+         # Set fields
+         self.input_ids = torch.tensor(
+             flatten_input_ids, dtype=torch.int32, device=device
+         )
+         self.pixel_values = [r.pixel_values for r in reqs]
+         self.image_sizes = [r.image_size for r in reqs]
+         self.image_offsets = [
+             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+         ]
+         self.req_pool_indices = req_pool_indices
+         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
+         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+         self.position_ids_offsets = position_ids_offsets
+         self.extend_num_tokens = extend_num_tokens
+         self.out_cache_loc = out_cache_loc
+         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+
+         self.temperatures = torch.tensor(
+             [r.sampling_params.temperature for r in reqs],
+             dtype=torch.float,
+             device=device,
+         ).view(-1, 1)
+         self.top_ps = torch.tensor(
+             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+         ).view(-1, 1)
+         self.top_ks = torch.tensor(
+             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+         ).view(-1, 1)
+         self.frequency_penalties = torch.tensor(
+             [r.sampling_params.frequency_penalty for r in reqs],
+             dtype=torch.float,
+             device=device,
+         )
+         self.presence_penalties = torch.tensor(
+             [r.sampling_params.presence_penalty for r in reqs],
+             dtype=torch.float,
+             device=device,
+         )
+         self.logit_bias = logit_bias
+
+     def check_decode_mem(self):
+         bs = len(self.reqs)
+         if self.token_to_kv_pool.available_size() >= bs:
+             return True
+
+         self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+
+         if self.token_to_kv_pool.available_size() >= bs:
+             return True
+
+         return False
+
+     def retract_decode(self):
+         sorted_indices = [i for i in range(len(self.reqs))]
+         # TODO(lsyin): improve the priority of retraction
+         sorted_indices.sort(
+             key=lambda i: (
+                 len(self.reqs[i].output_ids),
+                 -len(self.reqs[i].origin_input_ids),
+             ),
+             reverse=True,
+         )
+
+         retracted_reqs = []
+         seq_lens_cpu = self.seq_lens.cpu().numpy()
+         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+         while self.token_to_kv_pool.available_size() < len(self.reqs):
+             idx = sorted_indices.pop()
+             req = self.reqs[idx]
+             retracted_reqs.append(req)
+
+             # TODO: apply more fine-grained retraction
+             last_uncached_pos = len(req.prefix_indices)
+             token_indices = self.req_to_token_pool.req_to_token[
+                 req_pool_indices_cpu[idx]
+             ][last_uncached_pos : seq_lens_cpu[idx]]
+             self.token_to_kv_pool.free(token_indices)
+
+             # release the last node
+             self.tree_cache.dec_lock_ref(req.last_node)
+
+             req.prefix_indices = None
+             req.last_node = None
+             req.extend_input_len = 0
+
+             # For incremental logprobs
+             req.last_update_decode_tokens = 0
+             req.logprob_start_len = 10**9
+
+         self.filter_batch(sorted_indices)
+
+         return retracted_reqs
+
+     def check_for_jump_forward(self, model_runner):
+         jump_forward_reqs = []
+         filter_indices = [i for i in range(len(self.reqs))]
+
+         req_pool_indices_cpu = None
+
+         for i, req in enumerate(self.reqs):
+             if req.jump_forward_map is not None:
+                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                     req.regex_fsm_state
+                 )
+                 if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                     suffix_bytes = []
+                     continuation_range = range(0x80, 0xC0)
+                     cur_state = req.regex_fsm_state
+                     while (
+                         len(jump_forward_bytes)
+                         and jump_forward_bytes[0][0] in continuation_range
+                     ):
+                         # continuation bytes
+                         byte_edge = jump_forward_bytes.pop(0)
+                         suffix_bytes.append(byte_edge[0])
+                         cur_state = byte_edge[1]
+
+                     suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                     suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                     # Current ids, for cache and revert
+                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                     cur_output_ids = req.output_ids
+
+                     req.output_ids.extend(suffix_ids)
+                     decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                     if not decode_res:
+                         req.output_ids = cur_output_ids
+                         continue
+
+                     (
+                         jump_forward_str,
+                         next_state,
+                     ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+
+                     # Make the incrementally decoded text part of jump_forward_str
+                     # so that the UTF-8 sequence is not corrupted
+                     jump_forward_str = new_text + jump_forward_str
+                     if not req.jump_forward_and_retokenize(
+                         jump_forward_str, next_state
+                     ):
+                         req.output_ids = cur_output_ids
+                         continue
+
+                     # insert the old request into tree_cache
+                     if req_pool_indices_cpu is None:
+                         req_pool_indices_cpu = self.req_pool_indices.tolist()
+                     self.tree_cache.cache_req(
+                         token_ids=cur_all_ids,
+                         last_uncached_pos=len(req.prefix_indices),
+                         req_pool_idx=req_pool_indices_cpu[i],
+                     )
+
+                     # unlock the last node
+                     self.tree_cache.dec_lock_ref(req.last_node)
+
+                     # re-applying image padding
+                     if req.pixel_values is not None:
+                         (
+                             req.origin_input_ids,
+                             req.image_offset,
+                         ) = model_runner.model.pad_input_ids(
+                             req.origin_input_ids_unpadded,
+                             req.pad_value,
+                             req.pixel_values.shape,
+                             req.image_size,
+                         )
+
+                     jump_forward_reqs.append(req)
+                     filter_indices.remove(i)
+
+         if len(filter_indices) < len(self.reqs):
+             self.filter_batch(filter_indices)
+
+         return jump_forward_reqs
+
+     def prepare_for_decode(self, input_ids=None):
+         if input_ids is None:
+             input_ids = [
+                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+             ]
+         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+         self.seq_lens.add_(1)
+         self.prefix_lens = None
+
+         # Alloc mem
+         bs = len(self.reqs)
+         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+
+         if self.out_cache_loc is None:
+             print("Decode out of memory. This should never happen.")
+             self.tree_cache.pretty_print()
+             exit()
+
+         self.req_to_token_pool.req_to_token[
+             self.req_pool_indices, self.seq_lens - 1
+         ] = self.out_cache_loc
+
+     def filter_batch(self, unfinished_indices: List[int]):
+         self.reqs = [self.reqs[i] for i in unfinished_indices]
+         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
+         self.seq_lens = self.seq_lens[new_indices]
+         self.input_ids = None
+         self.req_pool_indices = self.req_pool_indices[new_indices]
+         self.prefix_lens = None
+         self.position_ids_offsets = self.position_ids_offsets[new_indices]
+         self.out_cache_loc = None
+         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
+         self.return_logprob = any(req.return_logprob for req in self.reqs)
+
+         for item in [
+             "temperatures",
+             "top_ps",
+             "top_ks",
+             "frequency_penalties",
+             "presence_penalties",
+             "logit_bias",
+         ]:
+             self_val = getattr(self, item, None)
+             if self_val is not None:  # logit_bias can be None
+                 setattr(self, item, self_val[new_indices])
+
+     def merge(self, other: "Batch"):
+         self.reqs.extend(other.reqs)
+
+         self.req_pool_indices = torch.concat(
+             [self.req_pool_indices, other.req_pool_indices]
+         )
+         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
+         self.prefix_lens = None
+         self.position_ids_offsets = torch.concat(
+             [self.position_ids_offsets, other.position_ids_offsets]
+         )
+         self.out_cache_loc = None
+         self.top_logprobs_nums.extend(other.top_logprobs_nums)
+         self.return_logprob = any(req.return_logprob for req in self.reqs)
+
+         for item in [
+             "temperatures",
+             "top_ps",
+             "top_ks",
+             "frequency_penalties",
+             "presence_penalties",
+         ]:
+             self_val = getattr(self, item, None)
+             other_val = getattr(other, item, None)
+             setattr(self, item, torch.concat([self_val, other_val]))
+
+         # logit_bias can be None
+         if self.logit_bias is not None or other.logit_bias is not None:
+             vocab_size = (
+                 self.logit_bias.shape[1]
+                 if self.logit_bias is not None
+                 else other.logit_bias.shape[1]
+             )
+             if self.logit_bias is None:
+                 self.logit_bias = torch.zeros(
+                     (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                 )
+             if other.logit_bias is None:
+                 other.logit_bias = torch.zeros(
+                     (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                 )
+             self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+
+     def sample(self, logits: torch.Tensor):
+         # Post process logits
+         logits = logits.contiguous()
+         logits.div_(self.temperatures)
+         if self.logit_bias is not None:
+             logits.add_(self.logit_bias)
+
+         has_regex = any(req.regex_fsm is not None for req in self.reqs)
+         if has_regex:
+             allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
+             for i, req in enumerate(self.reqs):
+                 if req.regex_fsm is not None:
+                     allowed_mask.zero_()
+                     allowed_mask[
+                         req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                     ] = 1
+                     logits[i].masked_fill_(~allowed_mask, float("-inf"))
+
+         # TODO(lmzheng): apply penalty
+         probs = torch.softmax(logits, dim=-1)
+         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
+         try:
+             sampled_index = torch.multinomial(probs_sort, num_samples=1)
+         except RuntimeError as e:
+             warnings.warn(f"Ignore errors in sampling: {e}")
+             sampled_index = torch.ones(
+                 probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
+             )
+         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
+             -1
+         )
+         batch_next_token_probs = torch.gather(
+             probs_sort, dim=1, index=sampled_index
+         ).view(-1)
+
+         if has_regex:
+             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
+             for i, req in enumerate(self.reqs):
+                 if req.regex_fsm is not None:
+                     req.regex_fsm_state = req.regex_fsm.get_next_state(
+                         req.regex_fsm_state, batch_next_token_ids_cpu[i]
+                     )
+
+         return batch_next_token_ids, batch_next_token_probs
+
+
+ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
+     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
+     probs_sort[
+         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
+     ] = 0.0
+     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+     return probs_sort, probs_idx
+
+
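
Batch.sample above divides by temperature, applies any logit bias and regex mask, filters with _top_p_top_k, and then samples with torch.multinomial. A self-contained sketch of that filtering step on random data, for reference only (the shapes mirror Batch.sample, where top_ps and top_ks are column vectors; this is an illustration, not code from the package):

    import torch

    probs = torch.softmax(torch.randn(2, 8), dim=-1)  # (batch, vocab)
    top_ps = torch.tensor([[0.9], [0.5]])             # nucleus threshold per request
    top_ks = torch.tensor([[4], [1]])                 # top-k limit per request

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0  # drop tokens outside the nucleus
    probs_sort[torch.arange(probs.shape[-1]).view(1, -1) >= top_ks] = 0.0  # drop ranks >= top_k
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])  # rescale; multinomial accepts unnormalized weights

    sampled = torch.multinomial(probs_sort, num_samples=1)
    next_token_ids = torch.gather(probs_idx, dim=1, index=sampled).view(-1)  # map back to vocab ids
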
+ @dataclass
+ class InputMetadata:
+     """Store all information of a forward pass."""
+
+     forward_mode: ForwardMode
+     batch_size: int
+     total_num_tokens: int
+     req_pool_indices: torch.Tensor
+     seq_lens: torch.Tensor
+     positions: torch.Tensor
+     req_to_token_pool: ReqToTokenPool
+     token_to_kv_pool: TokenToKVPool
+
+     # For extend
+     extend_seq_lens: torch.Tensor
+     extend_start_loc: torch.Tensor
+     extend_no_prefix: bool
+
+     # Output location of the KV cache
+     out_cache_loc: torch.Tensor = None
+
+     # Output options
+     return_logprob: bool = False
+     top_logprobs_nums: List[int] = None
+
+     # Triton attention backend
+     triton_max_seq_len: int = 0
+     triton_max_extend_len: int = 0
+     triton_start_loc: torch.Tensor = None
+     triton_prefix_lens: torch.Tensor = None
+
+     # FlashInfer attention backend
+     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+     @classmethod
+     def create(
+         cls,
+         model_runner,
+         forward_mode,
+         req_pool_indices,
+         seq_lens,
+         prefix_lens,
+         position_ids_offsets,
+         out_cache_loc,
+         top_logprobs_nums=None,
+         return_logprob=False,
+         skip_flashinfer_init=False,
+     ):
+         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+             init_flashinfer_args(
+                 forward_mode,
+                 model_runner,
+                 req_pool_indices,
+                 seq_lens,
+                 prefix_lens,
+                 model_runner.flashinfer_decode_wrapper,
+             )
+
+         batch_size = len(req_pool_indices)
+
+         if forward_mode == ForwardMode.DECODE:
+             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+             extend_seq_lens = extend_start_loc = extend_no_prefix = None
+             if not model_runner.server_args.disable_flashinfer:
+                 # This variable is not needed in this case;
+                 # we do not compute it to make it compatible with CUDA graph.
+                 total_num_tokens = None
+             else:
+                 total_num_tokens = int(torch.sum(seq_lens))
+         else:
+             seq_lens_cpu = seq_lens.cpu().numpy()
+             prefix_lens_cpu = prefix_lens.cpu().numpy()
+             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+             positions = torch.tensor(
+                 np.concatenate(
+                     [
+                         np.arange(
+                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                         )
+                         for i in range(batch_size)
+                     ],
+                     axis=0,
+                 ),
+                 device="cuda",
+             )
+             extend_seq_lens = seq_lens - prefix_lens
+             extend_start_loc = torch.zeros_like(seq_lens)
+             extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+             extend_no_prefix = torch.all(prefix_lens == 0)
+             total_num_tokens = int(torch.sum(seq_lens))
+
+         ret = cls(
+             forward_mode=forward_mode,
+             batch_size=batch_size,
+             total_num_tokens=total_num_tokens,
+             req_pool_indices=req_pool_indices,
+             seq_lens=seq_lens,
+             positions=positions,
+             req_to_token_pool=model_runner.req_to_token_pool,
+             token_to_kv_pool=model_runner.token_to_kv_pool,
+             out_cache_loc=out_cache_loc,
+             extend_seq_lens=extend_seq_lens,
+             extend_start_loc=extend_start_loc,
+             extend_no_prefix=extend_no_prefix,
+             return_logprob=return_logprob,
+             top_logprobs_nums=top_logprobs_nums,
+             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+         )
+
+         if model_runner.server_args.disable_flashinfer:
+             (
+                 ret.triton_max_seq_len,
+                 ret.triton_max_extend_len,
+                 ret.triton_start_loc,
+                 ret.triton_prefix_lens,
+             ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+         return ret
+
+
+ def init_flashinfer_args(
+     forward_mode,
+     model_runner,
+     req_pool_indices,
+     seq_lens,
+     prefix_lens,
+     flashinfer_decode_wrapper,
+ ):
+     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+     head_dim = model_runner.model_config.head_dim
+     batch_size = len(req_pool_indices)
+
+     if forward_mode == ForwardMode.DECODE:
+         paged_kernel_lens = seq_lens
+     else:
+         paged_kernel_lens = prefix_lens
+
+     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+     kv_indices = torch.cat(
+         [
+             model_runner.req_to_token_pool.req_to_token[
+                 req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+             ]
+             for i in range(batch_size)
+         ],
+         dim=0,
+     ).contiguous()
+     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+     if forward_mode == ForwardMode.DECODE:
+         flashinfer_decode_wrapper.end_forward()
+         flashinfer_decode_wrapper.begin_forward(
+             kv_indptr,
+             kv_indices,
+             kv_last_page_len,
+             num_qo_heads,
+             num_kv_heads,
+             head_dim,
+             1,
+         )
+     else:
+         # extend part
+         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+         model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+             qo_indptr,
+             qo_indptr,
+             num_qo_heads,
+             num_kv_heads,
+             head_dim,
+         )
+
+         # cached part
+         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+         model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+             qo_indptr,
+             kv_indptr,
+             kv_indices,
+             kv_last_page_len,
+             num_qo_heads,
+             num_kv_heads,
+             head_dim,
+             1,
+         )
+
+
+ def init_triton_args(forward_mode, seq_lens, prefix_lens):
+     batch_size = len(seq_lens)
+     max_seq_len = int(torch.max(seq_lens))
+     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+     start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+     if forward_mode == ForwardMode.DECODE:
+         max_extend_len = None
+     else:
+         extend_seq_lens = seq_lens - prefix_lens
+         max_extend_len = int(torch.max(extend_seq_lens))
+
+     return max_seq_len, max_extend_len, start_loc, prefix_lens
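
For reference, a small worked example of the values produced by init_triton_args above for an EXTEND batch of two requests (run on CPU here; the package allocates these tensors on "cuda"; this is an illustration, not code from the package):

    import torch

    seq_lens = torch.tensor([7, 5], dtype=torch.int32)     # total tokens per request
    prefix_lens = torch.tensor([3, 0], dtype=torch.int32)  # tokens already cached per request

    max_seq_len = int(torch.max(seq_lens))                    # 7
    start_loc = torch.zeros((2,), dtype=torch.int32)
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)        # tensor([0, 7]): offsets into the flattened batch
    max_extend_len = int(torch.max(seq_lens - prefix_lens))   # max(4, 5) = 5 tokens to extend
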