sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0

sglang/srt/layers/token_attention.py
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale
 
         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
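
Note on the token_attention.py hunks above: the kernel now carries an extra BLOCK_DPE block so that heads whose trailing dimensions hold a separate positional-encoding ("pe") component can be scored in two pieces; that is why a 576-dim head is split into BLOCK_DMODEL=512 and BLOCK_DPE=64, which lines up with the deepseek_v2.py and "enable_mla" changes elsewhere in this diff. A minimal PyTorch sketch of the score computation the kernel performs, using illustrative names rather than the kernel's actual arguments:

    import torch

    def attn_scores(q: torch.Tensor, k: torch.Tensor, sm_scale: float) -> torch.Tensor:
        # q: [576], k: [seq_len, 576]; mirrors BLOCK_DMODEL=512 / BLOCK_DPE=64
        q_main, q_pe = q[:512], q[512:]
        k_main, k_pe = k[:, :512], k[:, 512:]
        scores = (q_main * k_main).sum(dim=-1)   # att_value = tl.sum(q * k, 1)
        scores += (q_pe * k_pe).sum(dim=-1)      # att_value += tl.sum(qpe * kpe, 1)
        return scores * sm_scale

    q, k = torch.randn(576), torch.randn(8, 576)
    print(attn_scores(q, k, sm_scale=576 ** -0.5).shape)  # torch.Size([8])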

sglang/srt/managers/io_struct.py
@@ -92,7 +92,7 @@ class GenerateReqInput:
                 for element in parallel_sample_num_list
             )
             if parallel_sample_num > 1 and (not all_equal):
-                ## TODO cope with the case that the parallel_sample_num is different for different samples
+                # TODO cope with the case that the parallel_sample_num is different for different samples
                 raise ValueError(
                     "The parallel_sample_num should be the same for all samples in sample params."
                 )
@@ -103,14 +103,19 @@ class GenerateReqInput:
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
             num = parallel_sample_num + 1
-            if isinstance(self.text, List):
-                ## suppot batch operation
+            if isinstance(self.text, list):
+                # suppot batch operation
                 self.batch_size = len(self.text)
                 num = num * len(self.text)
+            elif isinstance(self.input_ids, list) and isinstance(
+                self.input_ids[0], list
+            ):
+                self.batch_size = len(self.input_ids)
+                num = num * len(self.input_ids)
             else:
                 self.batch_size = 1
         else:
-            ## support select operation
+            # support select operation
             num = len(self.text) if self.text is not None else len(self.input_ids)
             self.batch_size = num
 
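Note on the GenerateReqInput hunk above: the new elif branch lets a batch of token-id prompts (input_ids given as a list of lists) set batch_size the same way a batch of text prompts does. A small worked example of the bookkeeping, with hypothetical values rather than anything taken from the package:

    # With parallel sampling, each prompt expands into parallel_sample_num samples
    # plus one shared prefill request, so the request count is (n + 1) * batch_size.
    parallel_sample_num = 2                      # hypothetical n
    input_ids = [[1, 2, 3], [4, 5], [6, 7, 8]]   # batched token-id prompts

    num = parallel_sample_num + 1                # +1 for the original prefill stage
    if isinstance(input_ids, list) and isinstance(input_ids[0], list):
        batch_size = len(input_ids)
        num = num * len(input_ids)

    print(batch_size, num)  # 3 9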

sglang/srt/managers/schedule_batch.py
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from enum import IntEnum, auto
 from typing import List, Union
 
 import numpy as np
@@ -29,7 +28,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,21 +38,13 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }
 
 
 logger = logging.getLogger(__name__)
 
 
-class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-    EXTEND = auto()
-    # Decode one token.
-    DECODE = auto()
-
-
 class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
@@ -109,6 +100,9 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
+        # Memory info
+        self.req_pool_idx = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
@@ -283,13 +277,13 @@ class Req:
 
 
 @dataclass
-class Batch:
+class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache
 
     # Batched arguments to model runner
@@ -330,6 +324,9 @@ class Batch:
             return_logprob=return_logprob,
         )
 
+    def batch_size(self):
+        return len(self.reqs) if self.reqs is not None else 0
+
     def is_empty(self):
         return len(self.reqs) == 0
 
@@ -337,116 +334,127 @@ class Batch:
         # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
+    def alloc_req_slots(self, num_reqs):
+        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+        return req_pool_indices
+
+    def alloc_token_slots(self, num_tokens: int):
+        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+        if out_cache_loc is None:
+            if self.tree_cache is not None:
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
+                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+
+            if out_cache_loc is None:
+                logger.error("Prefill out of memory. Try to lower your batch size.")
+                if self.tree_cache is not None:
+                    self.tree_cache.pretty_print()
+                exit(1)
+
+        return out_cache_loc
+
+    def batch_sampling_params(self, vocab_size, int_token_logit_bias):
+        device = "cuda"
+        bs, reqs = self.batch_size(), self.reqs
+        self.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        self.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        self.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        self.frequency_penalties = torch.tensor(
+            [r.sampling_params.frequency_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+        self.presence_penalties = torch.tensor(
+            [r.sampling_params.presence_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+
+        # Handle logit bias but only allocate when needed
+        self.logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if self.logit_bias is None:
+                    self.logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
-        bs = len(self.reqs)
+        bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
         prefix_indices = [r.prefix_indices for r in reqs]
 
         # Handle prefix
-        flatten_input_ids = []
         extend_lens = []
         prefix_lens = []
         seq_lens = []
 
-        req_pool_indices = self.req_to_token_pool.alloc(bs)
-
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "Out of memory. "
-                "Please set a smaller number for `--max-running-requests`."
-            )
+        req_pool_indices_cpu = self.alloc_req_slots(bs)
 
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        for i in range(bs):
-            flatten_input_ids.extend(input_ids[i])
+        for i, req in enumerate(reqs):
+            req.req_pool_idx = req_pool_indices_cpu[i]
             extend_lens.append(len(input_ids[i]))
 
             if len(prefix_indices[i]) == 0:
                 prefix_lens.append(0)
             else:
                 prefix_lens.append(len(prefix_indices[i]))
-                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                     : len(prefix_indices[i])
                 ] = prefix_indices[i]
 
             seq_lens.append(prefix_lens[-1] + extend_lens[-1])
 
-        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
-
         # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-        if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
-
-            if out_cache_loc is None:
-                logger.error("Prefill out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
         pt = 0
-        for i in range(bs):
-            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+        for i, req in enumerate(reqs):
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][
                 prefix_lens[i] : prefix_lens[i] + extend_lens[i]
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias but only allocate when needed
-        logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if logit_bias is None:
-                    logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
-
         # Set fields
-        self.input_ids = torch.tensor(
-            flatten_input_ids, dtype=torch.int32, device=device
-        )
+        with torch.device("cuda"):
+            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
+            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
+
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
             r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
         ]
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
-        self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        self.frequency_penalties = torch.tensor(
-            [r.sampling_params.frequency_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.presence_penalties = torch.tensor(
-            [r.sampling_params.presence_penalty for r in reqs],
-            dtype=torch.float,
-            device=device,
-        )
-        self.logit_bias = logit_bias
+        self.batch_sampling_params(vocab_size, int_token_logit_bias)
 
     def check_decode_mem(self):
-        bs = len(self.reqs)
+        bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -471,7 +479,6 @@ class Batch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
@@ -489,20 +496,20 @@
 
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][: seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[idx]
-                ][last_uncached_pos : seq_lens_cpu[idx]]
+                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    last_uncached_pos : seq_lens_cpu[idx]
+                ]
                 self.token_to_kv_pool.free(token_indices)
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                self.req_to_token_pool.free(req.req_pool_idx)
 
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
@@ -540,8 +547,6 @@ class Batch:
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
-        req_pool_indices_cpu = None
-
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
                 jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
@@ -591,13 +596,11 @@ class Batch:
                 req.vid += 1
 
                 # insert the old request into tree_cache
-                if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.tolist()
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=cur_all_ids,
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                 )
 
                 # unlock the last node
@@ -633,13 +636,8 @@ class Batch:
         self.prefix_lens = None
 
         # Alloc mem
-        bs = len(self.reqs)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-        if self.out_cache_loc is None:
-            logger.error("Decode out of memory. This should never happen.")
-            self.tree_cache.pretty_print()
-            exit()
+        bs = self.batch_size()
+        self.out_cache_loc = self.alloc_token_slots(bs)
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
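
Note on the hunk above: prepare_for_decode now goes through the same alloc_token_slots helper that prepare_for_extend uses earlier in this diff, which evicts from the tree cache and retries before giving up. A simplified, self-contained sketch of that allocate/evict/retry pattern; the pool and cache classes below are illustrative stand-ins, not sglang's TokenToKVPool or RadixCache:

    from typing import List, Optional

    class ToyPool:
        """Stand-in slot pool, not sglang's TokenToKVPool."""

        def __init__(self, capacity: int):
            self.free_slots = list(range(capacity))

        def alloc(self, n: int) -> Optional[List[int]]:
            if len(self.free_slots) < n:
                return None
            out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
            return out

        def free(self, slots: List[int]) -> None:
            self.free_slots.extend(slots)

    class ToyCache:
        """Stand-in prefix cache that can hand slots back to the pool."""

        def __init__(self, cached_slots: List[int]):
            self.cached_slots = cached_slots

        def evict(self, num_tokens: int, free_fn) -> None:
            evicted = self.cached_slots[:num_tokens]
            self.cached_slots = self.cached_slots[num_tokens:]
            free_fn(evicted)

    def alloc_token_slots(pool: ToyPool, cache: ToyCache, num_tokens: int) -> List[int]:
        out = pool.alloc(num_tokens)
        if out is None:
            cache.evict(num_tokens, pool.free)   # reclaim cached slots, then retry
            out = pool.alloc(num_tokens)
        if out is None:
            raise RuntimeError("Out of memory: try a smaller batch size.")
        return out

    pool, cache = ToyPool(capacity=4), ToyCache(cached_slots=[100, 101])
    pool.alloc(3)                                # leave only one free slot
    print(alloc_token_slots(pool, cache, 2))     # eviction frees two slots, retry succeeds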
@@ -669,7 +667,7 @@ class Batch:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
-    def merge(self, other: "Batch"):
+    def merge(self, other: "ScheduleBatch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
@@ -766,229 +764,6 @@ class Batch:
         return batch_next_token_ids
 
 
-@dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
-
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    req_pool_indices: torch.Tensor
-    seq_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
-    # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
-
-    # Output options
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
-
-
-def init_flashinfer_args(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
-
-
 def top_k_top_p_sampling_from_probs_torch(
     probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
 ):