sglang-0.2.10-py3-none-any.whl → sglang-0.2.12-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
@@ -18,19 +18,18 @@ limitations under the License.
  import logging
  import warnings
  from dataclasses import dataclass
- from enum import IntEnum, auto
- from typing import List, Union
+ from typing import List, Optional, Union

- import numpy as np
  import torch
  from flashinfer.sampling import top_k_top_p_sampling_from_probs

+ import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.global_config import global_config
  from sglang.srt.constrained import RegexGuide
  from sglang.srt.constrained.jump_forward import JumpForwardMap
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
- from sglang.srt.mem_cache.radix_cache import RadixCache

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -46,15 +45,6 @@ global_server_args_dict = {
  logger = logging.getLogger(__name__)


- class ForwardMode(IntEnum):
-     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-     PREFILL = auto()
-     # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-     EXTEND = auto()
-     # Decode one token.
-     DECODE = auto()
-
-
  class BaseFinishReason:
      def __init__(self, is_error: bool = False):
          self.is_error = is_error
@@ -108,7 +98,10 @@ class Req:
          self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
          self.origin_input_ids = origin_input_ids
          self.output_ids = []  # Each decode stage's output ids
-         self.input_ids = None  # input_ids = origin_input_ids + output_ids
+         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+         # Memory info
+         self.req_pool_idx = None

          # For incremental decoding
          # ----- | --------- read_ids -------|
@@ -131,7 +124,7 @@ class Req:
          # For vision input
          self.pixel_values = None
          self.image_size = None
-         self.image_offset = 0
+         self.image_offset = None
          self.pad_value = None

          # Prefix info
@@ -149,6 +142,7 @@ class Req:

          # Logprobs
          self.return_logprob = False
+         self.embedding = None
          self.logprob_start_len = 0
          self.top_logprobs_num = 0
          self.normalized_prompt_logprob = None
@@ -169,6 +163,32 @@ class Req:
      def finished(self) -> bool:
          return self.finished_reason is not None

+     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+         self.fill_ids = self.origin_input_ids + self.output_ids
+         if tree_cache is not None:
+             self.prefix_indices, self.last_node = tree_cache.match_prefix(
+                 rid=self.rid, key=self.adjust_max_prefix_ids()
+             )
+         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
+
+     def adjust_max_prefix_ids(self):
+         self.fill_ids = self.origin_input_ids + self.output_ids
+         input_len = len(self.fill_ids)
+         max_prefix_len = input_len
+
+         if self.sampling_params.max_new_tokens > 0:
+             # Need at least one token to compute logits
+             max_prefix_len = min(max_prefix_len, input_len - 1)
+
+         if self.return_logprob:
+             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+
+         if self.normalized_prompt_logprob is None:
+             # Need at least two tokens to compute normalized logprob
+             max_prefix_len = min(max_prefix_len, input_len - 2)
+
+         return self.fill_ids[:max_prefix_len]
+
      # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
      def init_incremental_detokenize(self):
          first_iter = self.surr_offset is None or self.read_offset is None
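The new Req.init_next_round_input / Req.adjust_max_prefix_ids pair above decides how much of a request may be served from the prefix cache before the next extend step: the cached prefix is clamped so that at least one token (two while the normalized prompt logprob is still unset) is recomputed, and so that the cache never covers positions at or beyond logprob_start_len when logprobs are requested. A minimal standalone sketch of the same clamping rule follows; the helper name and example values are hypothetical and not part of sglang.

def max_cacheable_prefix(fill_ids, max_new_tokens, return_logprob,
                         logprob_start_len, normalized_prompt_logprob):
    input_len = len(fill_ids)
    max_prefix_len = input_len
    if max_new_tokens > 0:
        # keep at least one uncached token so the forward pass produces logits
        max_prefix_len = min(max_prefix_len, input_len - 1)
    if return_logprob:
        max_prefix_len = min(max_prefix_len, logprob_start_len)
    if normalized_prompt_logprob is None:
        # the normalized prompt logprob needs at least two uncached tokens
        max_prefix_len = min(max_prefix_len, input_len - 2)
    return fill_ids[:max_prefix_len]

# Example: a 6-token prompt that requests logprobs from position 3 can reuse
# at most 3 cached prefix tokens.
assert len(max_cacheable_prefix(list(range(6)), 16, True, 3, None)) == 3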
@@ -183,6 +203,8 @@ class Req:
          return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

      def get_next_inc_detokenization(self):
+         if self.tokenizer is None:
+             return False, ""
          read_ids, read_offset = self.init_incremental_detokenize()
          surr_ids = read_ids[:read_offset]

@@ -207,16 +229,18 @@ class Req:
              return

          if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-             self.finished_reason = FINISH_LENGTH(len(self.output_ids))
+             self.finished_reason = FINISH_LENGTH(
+                 length=self.sampling_params.max_new_tokens
+             )
              return

-         if (
-             self.output_ids[-1] == self.tokenizer.eos_token_id
-             and not self.sampling_params.ignore_eos
-         ):
-             self.finished_reason = FINISH_MATCHED_TOKEN(
-                 matched=self.tokenizer.eos_token_id
-             )
+         last_token_id = self.output_ids[-1]
+         if self.tokenizer is None:
+             matched_eos = last_token_id in self.sampling_params.stop_token_ids
+         else:
+             matched_eos = last_token_id == self.tokenizer.eos_token_id
+         if matched_eos and not self.sampling_params.ignore_eos:
+             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
              return

          if len(self.sampling_params.stop_strs) > 0:
@@ -284,20 +308,19 @@ class Req:


  @dataclass
- class Batch:
+ class ScheduleBatch:
      """Store all inforamtion of a batch."""

      # Request, memory pool, and cache
      reqs: List[Req]
      req_to_token_pool: ReqToTokenPool
      token_to_kv_pool: BaseTokenToKVPool
-     tree_cache: RadixCache
+     tree_cache: BasePrefixCache

      # Batched arguments to model runner
      input_ids: torch.Tensor = None
      req_pool_indices: torch.Tensor = None
      seq_lens: torch.Tensor = None
-     prefix_lens: torch.Tensor = None
      position_ids_offsets: torch.Tensor = None
      out_cache_loc: torch.Tensor = None
      extend_num_tokens: int = None
@@ -306,17 +329,11 @@ class Batch:
      return_logprob: bool = False
      top_logprobs_nums: List[int] = None

-     # For multimodal
-     pixel_values: List[torch.Tensor] = None
-     image_sizes: List[List[int]] = None
-     image_offsets: List[int] = None
-
      # Batched sampling params
      temperatures: torch.Tensor = None
      top_ps: torch.Tensor = None
      top_ks: torch.Tensor = None
-     frequency_penalties: torch.Tensor = None
-     presence_penalties: torch.Tensor = None
+     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
      logit_bias: torch.Tensor = None

      @classmethod
@@ -331,6 +348,9 @@ class Batch:
              return_logprob=return_logprob,
          )

+     def batch_size(self):
+         return len(self.reqs) if self.reqs is not None else 0
+
      def is_empty(self):
          return len(self.reqs) == 0

@@ -338,52 +358,22 @@ class Batch:
          # Return whether batch has at least 1 streaming request
          return any(r.stream for r in self.reqs)

-     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-         device = "cuda"
-         bs = len(self.reqs)
-         reqs = self.reqs
-         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-         prefix_indices = [r.prefix_indices for r in reqs]
-
-         # Handle prefix
-         flatten_input_ids = []
-         extend_lens = []
-         prefix_lens = []
-         seq_lens = []
-
-         req_pool_indices = self.req_to_token_pool.alloc(bs)
-
+     def alloc_req_slots(self, num_reqs):
+         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
          if req_pool_indices is None:
              raise RuntimeError(
                  "Out of memory. "
                  "Please set a smaller number for `--max-running-requests`."
              )
+         return req_pool_indices

-         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-         for i in range(bs):
-             flatten_input_ids.extend(input_ids[i])
-             extend_lens.append(len(input_ids[i]))
-
-             if len(prefix_indices[i]) == 0:
-                 prefix_lens.append(0)
-             else:
-                 prefix_lens.append(len(prefix_indices[i]))
-                 self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                     : len(prefix_indices[i])
-                 ] = prefix_indices[i]
-
-             seq_lens.append(prefix_lens[-1] + extend_lens[-1])
+     def alloc_token_slots(self, num_tokens: int):
+         out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

-         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
-         # Allocate memory
-         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
          if out_cache_loc is None:
              if self.tree_cache is not None:
-                 self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                 out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+                 self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                 out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

              if out_cache_loc is None:
                  logger.error("Prefill out of memory. Try to lower your batch size.")
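The new alloc_req_slots / alloc_token_slots helpers above centralize an allocate-evict-retry pattern now shared by the extend and decode paths: if the KV pool cannot satisfy an allocation, entries are evicted from the prefix cache back into the pool and the allocation is retried once. A toy sketch of that pattern under simplified assumptions (ToyPool and ToyCache are hypothetical stand-ins, not sglang classes):

class ToyPool:
    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self, n):
        if len(self.free_slots) < n:
            return None
        taken, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return taken

    def free(self, slots):
        self.free_slots.extend(slots)

class ToyCache:
    def __init__(self):
        self.cached = []  # slots pinned by reusable cached prefixes

    def evict(self, num_tokens, free_fn):
        evicted, self.cached = self.cached[:num_tokens], self.cached[num_tokens:]
        free_fn(evicted)

pool, cache = ToyPool(4), ToyCache()
cache.cached = pool.alloc(3)   # 3 of 4 slots are held by cached prefixes
loc = pool.alloc(2)            # only 1 slot is free, so this fails
if loc is None:
    cache.evict(2, pool.free)  # release cached slots back to the pool
    loc = pool.alloc(2)        # the retry succeeds
assert loc is not None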
@@ -391,40 +381,11 @@ class Batch:
                      self.tree_cache.pretty_print()
                  exit(1)

-         pt = 0
-         for i in range(bs):
-             self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
-                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-             ] = out_cache_loc[pt : pt + extend_lens[i]]
-             pt += extend_lens[i]
-
-         # Handle logit bias but only allocate when needed
-         logit_bias = None
-         for i in range(bs):
-             if reqs[i].sampling_params.dtype == "int":
-                 if logit_bias is None:
-                     logit_bias = torch.zeros(
-                         (bs, vocab_size), dtype=torch.float32, device=device
-                     )
-                 logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
-         # Set fields
-         self.input_ids = torch.tensor(
-             flatten_input_ids, dtype=torch.int32, device=device
-         )
-         self.pixel_values = [r.pixel_values for r in reqs]
-         self.image_sizes = [r.image_size for r in reqs]
-         self.image_offsets = [
-             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
-         ]
-         self.req_pool_indices = req_pool_indices
-         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
-         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-         self.position_ids_offsets = position_ids_offsets
-         self.extend_num_tokens = extend_num_tokens
-         self.out_cache_loc = out_cache_loc
-         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+         return out_cache_loc

+     def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+         device = "cuda"
+         bs, reqs = self.batch_size(), self.reqs
          self.temperatures = torch.tensor(
              [r.sampling_params.temperature for r in reqs],
              dtype=torch.float,
@@ -436,20 +397,79 @@ class Batch:
              dtype=torch.float,
          self.top_ks = torch.tensor(
              [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-         self.frequency_penalties = torch.tensor(
-             [r.sampling_params.frequency_penalty for r in reqs],
-             dtype=torch.float,
-             device=device,
-         )
-         self.presence_penalties = torch.tensor(
-             [r.sampling_params.presence_penalty for r in reqs],
-             dtype=torch.float,
+
+         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+         # should not add hefty computation overhead other than simple checks.
+         #
+         # While we choose not to even create the class instances if they are not required, this
+         # could add additional complexity to the {ScheduleBatch} class, especially we need to
+         # handle {filter_batch()} and {merge()} cases as well.
+         self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+             vocab_size=vocab_size,
+             batch=self,
              device=device,
+             Penalizers={
+                 penaltylib.BatchedFrequencyPenalizer,
+                 penaltylib.BatchedMinNewTokensPenalizer,
+                 penaltylib.BatchedPresencePenalizer,
+                 penaltylib.BatchedRepetitionPenalizer,
+             },
          )
-         self.logit_bias = logit_bias
+
+         # Handle logit bias but only allocate when needed
+         self.logit_bias = None
+         for i in range(bs):
+             if reqs[i].sampling_params.dtype == "int":
+                 if self.logit_bias is None:
+                     self.logit_bias = torch.zeros(
+                         (bs, vocab_size), dtype=torch.float32, device=device
+                     )
+                 self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
+     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+         bs = self.batch_size()
+         reqs = self.reqs
+         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+         extend_num_tokens = sum(len(ids) for ids in input_ids)
+         seq_lens = []
+
+         # Allocate memory
+         req_pool_indices_cpu = self.alloc_req_slots(bs)
+         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+         pt = 0
+         for i, req in enumerate(reqs):
+             req.req_pool_idx = req_pool_indices_cpu[i]
+             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+             ext_len = seq_len - pre_len
+             seq_lens.append(seq_len)
+
+             if pre_len > 0:
+                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                     :pre_len
+                 ] = req.prefix_indices
+
+             self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                 out_cache_loc[pt : pt + ext_len]
+             )
+             pt += ext_len
+
+         # Set fields
+         with torch.device("cuda"):
+             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+             self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+             self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+             self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+
+         self.extend_num_tokens = extend_num_tokens
+         self.out_cache_loc = out_cache_loc
+         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+
+         self.batch_sampling_params(vocab_size, int_token_logit_bias)

      def check_decode_mem(self):
-         bs = len(self.reqs)
+         bs = self.batch_size()
          if self.token_to_kv_pool.available_size() >= bs:
              return True

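The per-batch frequency_penalties and presence_penalties tensors give way above to a single penaltylib.BatchedPenalizerOrchestrator (backed by the new sglang/srt/sampling/penaltylib modules in the file list); as the inline comment notes, each penalizer inspects the batch's sampling params and turns itself into a no-op when it is not required. Purely as an illustration of the arithmetic a batched frequency/presence penalty performs on the logits, and not the penaltylib implementation, a sketch:

import torch

def apply_freq_presence_penalties(logits, output_token_counts,
                                  frequency_penalty, presence_penalty):
    # logits, output_token_counts: [batch, vocab]; the penalties: [batch]
    logits = logits - frequency_penalty[:, None] * output_token_counts
    logits = logits - presence_penalty[:, None] * (output_token_counts > 0).float()
    return logits

logits = torch.zeros(2, 8)
counts = torch.zeros(2, 8)
counts[0, 3] = 2  # request 0 already emitted token 3 twice
out = apply_freq_presence_penalties(
    logits, counts, torch.tensor([0.5, 0.0]), torch.tensor([1.0, 0.0])
)
assert out[0, 3].item() == -(0.5 * 2 + 1.0)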
@@ -474,7 +494,6 @@ class Batch:

          retracted_reqs = []
          seq_lens_cpu = self.seq_lens.cpu().numpy()
-         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
          while (
              self.token_to_kv_pool.available_size()
              < len(sorted_indices) * global_config.retract_decode_steps
@@ -492,20 +511,20 @@ class Batch:

              if isinstance(self.tree_cache, ChunkCache):
                  # ChunkCache does not have eviction
-                 token_indices = self.req_to_token_pool.req_to_token[
-                     req_pool_indices_cpu[idx]
-                 ][: seq_lens_cpu[idx]]
+                 token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                     : seq_lens_cpu[idx]
+                 ]
                  self.token_to_kv_pool.free(token_indices)
-                 self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                 self.req_to_token_pool.free(req.req_pool_idx)
                  del self.tree_cache.entries[req.rid]
              else:
                  # TODO: apply more fine-grained retraction
                  last_uncached_pos = len(req.prefix_indices)
-                 token_indices = self.req_to_token_pool.req_to_token[
-                     req_pool_indices_cpu[idx]
-                 ][last_uncached_pos : seq_lens_cpu[idx]]
+                 token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                     last_uncached_pos : seq_lens_cpu[idx]
+                 ]
                  self.token_to_kv_pool.free(token_indices)
-                 self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                 self.req_to_token_pool.free(req.req_pool_idx)

              # release the last node
              self.tree_cache.dec_lock_ref(req.last_node)
@@ -518,7 +537,7 @@ class Batch:
              residual_size = max(0, residual_size)
              self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

-             req.prefix_indices = None
+             req.prefix_indices = []
              req.last_node = None
              req.extend_input_len = 0

@@ -543,8 +562,6 @@ class Batch:
          jump_forward_reqs = []
          filter_indices = [i for i in range(len(self.reqs))]

-         req_pool_indices_cpu = None
-
          for i, req in enumerate(self.reqs):
              if req.jump_forward_map is not None:
                  jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -594,17 +611,7 @@ class Batch:
                      req.vid += 1

                      # insert the old request into tree_cache
-                     if req_pool_indices_cpu is None:
-                         req_pool_indices_cpu = self.req_pool_indices.tolist()
-                     self.tree_cache.cache_req(
-                         rid=req.rid,
-                         token_ids=cur_all_ids,
-                         last_uncached_pos=len(req.prefix_indices),
-                         req_pool_idx=req_pool_indices_cpu[i],
-                     )
-
-                     # unlock the last node
-                     self.tree_cache.dec_lock_ref(req.last_node)
+                     self.tree_cache.cache_finished_req(req, cur_all_ids)

                      # re-applying image padding
                      if req.pixel_values is not None:
@@ -621,66 +628,74 @@ class Batch:
                      jump_forward_reqs.append(req)
                      filter_indices.remove(i)

-         if len(filter_indices) < len(self.reqs):
-             self.filter_batch(filter_indices)
+         self.filter_batch(filter_indices)

          return jump_forward_reqs

      def prepare_for_decode(self, input_ids=None):
          if input_ids is None:
              input_ids = [
-                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+                 for r in self.reqs
              ]
+         else:
+             self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+
          self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
          self.seq_lens.add_(1)
-         self.prefix_lens = None

          # Alloc mem
-         bs = len(self.reqs)
-         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-         if self.out_cache_loc is None:
-             logger.error("Decode out of memory. Try to lower your batch size.")
-             if self.tree_cache is not None:
-                 self.tree_cache.pretty_print()
-             exit(1)
+         bs = self.batch_size()
+         self.out_cache_loc = self.alloc_token_slots(bs)

          self.req_to_token_pool.req_to_token[
              self.req_pool_indices, self.seq_lens - 1
          ] = self.out_cache_loc

      def filter_batch(self, unfinished_indices: List[int]):
+         if unfinished_indices is None or len(unfinished_indices) == 0:
+             # Filter out all requests
+             self.reqs = []
+             return
+
+         if len(unfinished_indices) == len(self.reqs):
+             # No need to filter
+             return
+
          self.reqs = [self.reqs[i] for i in unfinished_indices]
          new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
          self.seq_lens = self.seq_lens[new_indices]
          self.input_ids = None
          self.req_pool_indices = self.req_pool_indices[new_indices]
-         self.prefix_lens = None
          self.position_ids_offsets = self.position_ids_offsets[new_indices]
          self.out_cache_loc = None
          self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
          self.return_logprob = any(req.return_logprob for req in self.reqs)

+         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
          for item in [
              "temperatures",
              "top_ps",
              "top_ks",
-             "frequency_penalties",
-             "presence_penalties",
              "logit_bias",
          ]:
              self_val = getattr(self, item, None)
              if self_val is not None:  # logit_bias can be None
                  setattr(self, item, self_val[new_indices])

-     def merge(self, other: "Batch"):
+     def merge(self, other: "ScheduleBatch"):
+         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
+         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
+         # needs to be called with pre-merged Batch.reqs.
+         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
          self.reqs.extend(other.reqs)

          self.req_pool_indices = torch.concat(
              [self.req_pool_indices, other.req_pool_indices]
          )
          self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-         self.prefix_lens = None
          self.position_ids_offsets = torch.concat(
              [self.position_ids_offsets, other.position_ids_offsets]
          )
@@ -692,8 +707,6 @@ class Batch:
              "temperatures",
              "top_ps",
              "top_ks",
-             "frequency_penalties",
-             "presence_penalties",
          ]:
              self_val = getattr(self, item, None)
              other_val = getattr(other, item, None)
@@ -717,6 +730,7 @@ class Batch:
              self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

      def sample(self, logits: torch.Tensor):
+         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
          # Post process logits
          logits = logits.contiguous()
          logits.div_(self.temperatures)
@@ -734,7 +748,8 @@ class Batch:
                      ] = 1
                      logits[i].masked_fill_(~allowed_mask, float("-inf"))

-         # TODO(lmzheng): apply penalty
+         logits = self.penalizer_orchestrator.apply(logits)
+
          probs = torch.softmax(logits, dim=-1)

          if not global_server_args_dict["disable_flashinfer_sampling"]:
@@ -767,230 +782,9 @@ class Batch:
                      req.regex_fsm_state, batch_next_token_ids_cpu[i]
                  )

-         return batch_next_token_ids
-
-
- @dataclass
- class InputMetadata:
-     """Store all inforamtion of a forward pass."""
-
-     forward_mode: ForwardMode
-     batch_size: int
-     total_num_tokens: int
-     req_pool_indices: torch.Tensor
-     seq_lens: torch.Tensor
-     positions: torch.Tensor
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: BaseTokenToKVPool
+         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

-     # For extend
-     extend_seq_lens: torch.Tensor
-     extend_start_loc: torch.Tensor
-     extend_no_prefix: bool
-
-     # Output location of the KV cache
-     out_cache_loc: torch.Tensor = None
-
-     # Output options
-     return_logprob: bool = False
-     top_logprobs_nums: List[int] = None
-
-     # Trition attention backend
-     triton_max_seq_len: int = 0
-     triton_max_extend_len: int = 0
-     triton_start_loc: torch.Tensor = None
-     triton_prefix_lens: torch.Tensor = None
-
-     # FlashInfer attention backend
-     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-     flashinfer_use_ragged: bool = False
-
-     @classmethod
-     def create(
-         cls,
-         model_runner,
-         forward_mode,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         top_logprobs_nums=None,
-         return_logprob=False,
-         skip_flashinfer_init=False,
-     ):
-         flashinfer_use_ragged = False
-         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                 flashinfer_use_ragged = True
-             init_flashinfer_args(
-                 forward_mode,
-                 model_runner,
-                 req_pool_indices,
-                 seq_lens,
-                 prefix_lens,
-                 model_runner.flashinfer_decode_wrapper,
-                 flashinfer_use_ragged,
-             )
-
-         batch_size = len(req_pool_indices)
-
-         if forward_mode == ForwardMode.DECODE:
-             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-             extend_seq_lens = extend_start_loc = extend_no_prefix = None
-             if not model_runner.server_args.disable_flashinfer:
-                 # This variable is not needed in this case,
-                 # we do not compute it to make it compatbile with cuda graph.
-                 total_num_tokens = None
-             else:
-                 total_num_tokens = int(torch.sum(seq_lens))
-         else:
-             seq_lens_cpu = seq_lens.cpu().numpy()
-             prefix_lens_cpu = prefix_lens.cpu().numpy()
-             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-             positions = torch.tensor(
-                 np.concatenate(
-                     [
-                         np.arange(
-                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                         )
-                         for i in range(batch_size)
-                     ],
-                     axis=0,
-                 ),
-                 device="cuda",
-             )
-             extend_seq_lens = seq_lens - prefix_lens
-             extend_start_loc = torch.zeros_like(seq_lens)
-             extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-             extend_no_prefix = torch.all(prefix_lens == 0)
-             total_num_tokens = int(torch.sum(seq_lens))
-
-         ret = cls(
-             forward_mode=forward_mode,
-             batch_size=batch_size,
-             total_num_tokens=total_num_tokens,
-             req_pool_indices=req_pool_indices,
-             seq_lens=seq_lens,
-             positions=positions,
-             req_to_token_pool=model_runner.req_to_token_pool,
-             token_to_kv_pool=model_runner.token_to_kv_pool,
-             out_cache_loc=out_cache_loc,
-             extend_seq_lens=extend_seq_lens,
-             extend_start_loc=extend_start_loc,
-             extend_no_prefix=extend_no_prefix,
-             return_logprob=return_logprob,
-             top_logprobs_nums=top_logprobs_nums,
-             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-             flashinfer_use_ragged=flashinfer_use_ragged,
-         )
-
-         if model_runner.server_args.disable_flashinfer:
-             (
-                 ret.triton_max_seq_len,
-                 ret.triton_max_extend_len,
-                 ret.triton_start_loc,
-                 ret.triton_prefix_lens,
-             ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-         return ret
-
-
- def init_flashinfer_args(
-     forward_mode,
-     model_runner,
-     req_pool_indices,
-     seq_lens,
-     prefix_lens,
-     flashinfer_decode_wrapper,
-     flashinfer_use_ragged=False,
- ):
-     """Init auxiliary variables for FlashInfer attention backend."""
-     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-     head_dim = model_runner.model_config.head_dim
-     batch_size = len(req_pool_indices)
-     total_num_tokens = int(torch.sum(seq_lens))
-
-     if flashinfer_use_ragged:
-         paged_kernel_lens = prefix_lens
-     else:
-         paged_kernel_lens = seq_lens
-
-     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-     kv_indices = torch.cat(
-         [
-             model_runner.req_to_token_pool.req_to_token[
-                 req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-             ]
-             for i in range(batch_size)
-         ],
-         dim=0,
-     ).contiguous()
-     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-     if forward_mode == ForwardMode.DECODE:
-         flashinfer_decode_wrapper.end_forward()
-         flashinfer_decode_wrapper.begin_forward(
-             kv_indptr,
-             kv_indices,
-             kv_last_page_len,
-             num_qo_heads,
-             num_kv_heads,
-             head_dim,
-             1,
-         )
-     else:
-         # extend part
-         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-         if flashinfer_use_ragged:
-             model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-             model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                 qo_indptr,
-                 qo_indptr,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-             )
-
-         # cached part
-         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-         model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-             qo_indptr,
-             kv_indptr,
-             kv_indices,
-             kv_last_page_len,
-             num_qo_heads,
-             num_kv_heads,
-             head_dim,
-             1,
-         )
-
-
- def init_triton_args(forward_mode, seq_lens, prefix_lens):
-     """Init auxiliary variables for triton attention backend."""
-     batch_size = len(seq_lens)
-     max_seq_len = int(torch.max(seq_lens))
-     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-     start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-     if forward_mode == ForwardMode.DECODE:
-         max_extend_len = None
-     else:
-         extend_seq_lens = seq_lens - prefix_lens
-         max_extend_len = int(torch.max(extend_seq_lens))
-
-     return max_seq_len, max_extend_len, start_loc, prefix_lens
+         return batch_next_token_ids


  def top_k_top_p_sampling_from_probs_torch(
@@ -1009,7 +803,7 @@ def top_k_top_p_sampling_from_probs_torch(
          sampled_index = torch.multinomial(probs_sort, num_samples=1)
      except RuntimeError:
          batch_next_token_ids = torch.zeros(
-             (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
          )
          success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
          return batch_next_token_ids, success