sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -23,6 +23,7 @@ import os
  import signal
  import sys
  import time
+ import uuid
  from typing import Dict, List, Optional, Tuple, Union

  import fastapi
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
  BatchEmbeddingOut,
  BatchStrOut,
  BatchTokenIDOut,
+ CloseSessionReqInput,
  EmbeddingReqInput,
  FlushCacheReq,
  GenerateReqInput,
  GetMemPoolSizeReq,
  GetMemPoolSizeReqOutput,
+ OpenSessionReqInput,
+ OpenSessionReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
@@ -146,6 +150,9 @@ class TokenizerManager:
  self.model_update_lock = asyncio.Lock()
  self.model_update_result = None

+ # For session info
+ self.session_futures = {} # session_id -> asyncio event
+
  # Others
  self.gracefully_exit = False

@@ -211,6 +218,8 @@ class TokenizerManager:
  return_logprob = obj.return_logprob
  logprob_start_len = obj.logprob_start_len
  top_logprobs_num = obj.top_logprobs_num
+ session_id = obj.session_id
+ session_rid = obj.session_rid

  if len(input_ids) >= self.context_len:
  raise ValueError(
@@ -236,6 +245,8 @@ class TokenizerManager:
  top_logprobs_num,
  obj.stream,
  obj.lora_path,
+ session_id=session_id,
+ session_rid=session_rid,
  )
  elif isinstance(obj, EmbeddingReqInput):
  tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +462,26 @@ class TokenizerManager:
  else:
  return False, "Another update is in progress. Please try again later."

+ async def open_session(
+ self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
+ ):
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ session_id = uuid.uuid4().hex
+ obj.session_id = session_id
+ self.send_to_scheduler.send_pyobj(obj)
+ self.session_futures[session_id] = asyncio.Future()
+ session_id = await self.session_futures[session_id]
+ del self.session_futures[session_id]
+ return session_id
+
+ async def close_session(
+ self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
+ ):
+ assert not self.to_create_loop, "close session should not be the first request"
+ await self.send_to_scheduler.send_pyobj(obj)
+
  def create_abort_task(self, obj: GenerateReqInput):
  # Abort the request if the client is disconnected.
  async def abort_request():
@@ -521,6 +552,11 @@ class TokenizerManager:
  if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
  self.mem_pool_size.set_result(self.mem_pool_size_tmp)
  continue
+ elif isinstance(recv_obj, OpenSessionReqOutput):
+ self.session_futures[recv_obj.session_id].set_result(
+ recv_obj.session_id
+ )
+ continue

  assert isinstance(
  recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -536,11 +572,13 @@ class TokenizerManager:
  out_dict = {
  "text": recv_obj.output_strs[i],
  "meta_info": recv_obj.meta_info[i],
+ "session_id": recv_obj.session_ids[i],
  }
  elif isinstance(recv_obj, BatchTokenIDOut):
  out_dict = {
  "token_ids": recv_obj.output_ids[i],
  "meta_info": recv_obj.meta_info[i],
+ "session_id": recv_obj.session_ids[i],
  }
  else:
  assert isinstance(recv_obj, BatchEmbeddingOut)
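
Aside: the open_session path above uses a small future-based handshake: the caller parks on an asyncio.Future keyed by the freshly generated session_id, and the handle loop resolves it when OpenSessionReqOutput arrives. A minimal standalone sketch of that pattern (illustrative only, not sglang code; the send callback stands in for send_to_scheduler.send_pyobj):

import asyncio
import uuid

session_futures = {}  # session_id -> asyncio.Future, as in TokenizerManager

async def open_session(send_to_scheduler) -> str:
    session_id = uuid.uuid4().hex
    session_futures[session_id] = asyncio.get_running_loop().create_future()
    send_to_scheduler(session_id)                 # stand-in for send_pyobj(obj)
    result = await session_futures[session_id]    # resolved by the receive loop
    del session_futures[session_id]
    return result

def on_open_session_reply(session_id: str):
    # What the handle loop does when it sees OpenSessionReqOutput.
    session_futures[session_id].set_result(session_id)

async def main():
    loop = asyncio.get_running_loop()
    # Pretend the scheduler replies asynchronously.
    sid = await open_session(lambda s: loop.call_soon(on_open_session_reply, s))
    print("opened session", sid)

asyncio.run(main())

Keying the futures by session_id lets many concurrent open_session calls share one receive loop without extra locking.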
sglang/srt/managers/tp_worker.py
@@ -16,6 +16,7 @@ limitations under the License.
  """A tensor parallel worker."""

  import logging
+ import threading
  from typing import Optional

  from sglang.srt.configs.model_config import ModelConfig
@@ -134,9 +135,19 @@ class TpModelWorker:
  self.model_runner.token_to_kv_pool,
  )

- def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+ def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
+ forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+ self.model_runner.forward(forward_batch)
+
+ def forward_batch_generation(
+ self,
+ model_worker_batch: ModelWorkerBatch,
+ launch_done: Optional[threading.Event] = None,
+ ):
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
  logits_output = self.model_runner.forward(forward_batch)
+ if launch_done:
+ launch_done.set()
  next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
  return logits_output, next_token_ids

sglang/srt/managers/tp_worker_overlap_thread.py
@@ -15,9 +15,9 @@ limitations under the License.

  """A tensor parallel worker."""

+ import dataclasses
  import logging
  import threading
- import time
  from queue import Queue
  from typing import Optional

@@ -26,7 +26,6 @@ import torch
  from sglang.srt.managers.io_struct import UpdateWeightReqInput
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch
  from sglang.srt.managers.tp_worker import TpModelWorker
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.server_args import ServerArgs

  logger = logging.getLogger(__name__)
@@ -56,6 +55,7 @@ class TpModelWorkerClient:
  self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  self.max_running_requests = self.worker.max_running_requests
  self.device = self.worker.device
+ self.gpu_id = gpu_id

  # Init future mappings
  self.future_token_ids_ct = 0
@@ -73,12 +73,6 @@ class TpModelWorkerClient:
  )
  self.forward_thread.start()

- self.copy_queue = Queue()
- self.copy_thread = threading.Thread(
- target=self.copy_thread_func,
- )
- self.copy_thread.start()
-
  def get_worker_info(self):
  return self.worker.get_worker_info()

@@ -98,15 +92,25 @@ class TpModelWorkerClient:
  with torch.cuda.stream(self.forward_stream):
  self.forward_thread_func_()

- @torch.inference_mode()
+ @torch.no_grad()
  def forward_thread_func_(self):
+ batch_pt = 0
+ batch_lists = [None] * 2
+
  while True:
- self.has_inflight_batch = False
  model_worker_batch, future_token_ids_ct = self.input_queue.get()
  if not model_worker_batch:
  break
- self.has_inflight_batch = True
- self.launch_event = threading.Event()
+
+ # Keep a reference of model_worker_batch by storing it into a list.
+ # Otherwise, the tensor members of model_worker_batch will be released
+ # by pytorch and cause CUDA illegal memory access errors.
+ batch_lists[batch_pt % 2] = model_worker_batch
+ batch_pt += 1
+
+ # Create event
+ self.launch_done = threading.Event()
+ copy_done = torch.cuda.Event()

  # Resolve future tokens in the input
  input_ids = model_worker_batch.input_ids
@@ -114,7 +118,7 @@ class TpModelWorkerClient:

  # Run forward
  logits_output, next_token_ids = self.worker.forward_batch_generation(
- model_worker_batch
+ model_worker_batch, self.launch_done
  )

  # Update the future token ids map
@@ -139,44 +143,45 @@ class TpModelWorkerClient:
  )
  )
  next_token_ids = next_token_ids.to("cpu", non_blocking=True)
- copy_event = torch.cuda.Event(blocking=True)
- copy_event.record()
+ copy_done.record()

- self.launch_event.set()
- self.copy_queue.put((copy_event, logits_output, next_token_ids))
+ self.output_queue.put((copy_done, logits_output, next_token_ids))

- def copy_thread_func(self):
- while True:
- copy_event, logits_output, next_token_ids = self.copy_queue.get()
- if not copy_event:
- break
- while not copy_event.query():
- time.sleep(1e-5)
+ def resolve_batch_result(self, bid: int):
+ copy_done, logits_output, next_token_ids = self.output_queue.get()
+ copy_done.synchronize()
+ self.launch_done.wait()

- if logits_output.next_token_logprobs is not None:
- logits_output.next_token_logprobs = (
- logits_output.next_token_logprobs.tolist()
+ if logits_output.next_token_logprobs is not None:
+ logits_output.next_token_logprobs = (
+ logits_output.next_token_logprobs.tolist()
+ )
+ if logits_output.input_token_logprobs is not None:
+ logits_output.input_token_logprobs = (
+ logits_output.input_token_logprobs.tolist()
  )
- if logits_output.input_token_logprobs is not None:
- logits_output.input_token_logprobs = (
- logits_output.input_token_logprobs.tolist()
- )
- logits_output.normalized_prompt_logprobs = (
- logits_output.normalized_prompt_logprobs.tolist()
- )
-
- self.output_queue.put((logits_output, next_token_ids.tolist()))
-
- def resulve_batch_result(self, bid: int):
- logits_output, next_token_ids = self.output_queue.get()
- if self.has_inflight_batch:
- # Wait until the batch is launched
- self.launch_event.wait()
+ logits_output.normalized_prompt_logprobs = (
+ logits_output.normalized_prompt_logprobs.tolist()
+ )
+ next_token_ids = next_token_ids.tolist()
  return logits_output, next_token_ids

  def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+ # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
+ sampling_info = model_worker_batch.sampling_info
+ sampling_info.update_penalties()
+ model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
+ sampling_info,
+ sampling_info_done=threading.Event(),
+ scaling_penalties=sampling_info.scaling_penalties,
+ linear_penalties=sampling_info.linear_penalties,
+ )
+
+ # A cuda stream sync here to avoid the cuda illegal memory access error.
+ torch.cuda.current_stream().synchronize()
+
  # Push a new batch to the queue
- self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+ self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

  # Allocate output future objects
  bs = len(model_worker_batch.seq_lens)
@@ -192,16 +197,8 @@ class TpModelWorkerClient:
  ) % self.future_token_ids_limit
  return None, future_next_token_ids

- def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
- forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- logits_output = self.model_runner.forward(forward_batch)
- embeddings = logits_output.embeddings
- return embeddings
-
  def update_weights(self, recv_req: UpdateWeightReqInput):
- success, message = self.model_runner.update_weights(
- recv_req.model_path, recv_req.load_format
- )
+ success, message = self.worker.update_weights(recv_req)
  return success, message

  def __delete__(self):
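
Aside: the dedicated copy thread (copy_thread_func, polling copy_event.query()) is replaced by a copy_done CUDA event recorded right after the non-blocking device-to-host copies; resolve_batch_result then synchronizes on that event only when the scheduler actually needs the tokens. A small sketch of the record/synchronize pattern (illustrative only, not the sglang implementation):

import torch

if torch.cuda.is_available():
    next_token_ids = torch.randint(0, 32000, (8,), device="cuda")
    copy_done = torch.cuda.Event()

    host_ids = next_token_ids.to("cpu", non_blocking=True)  # async D2H copy
    copy_done.record()        # marks the point where the copy was enqueued

    # ... the CPU is free to prepare the next batch here ...

    copy_done.synchronize()   # block only when the result is actually needed
    print(host_ids.tolist())  # safe to read after synchronize()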
sglang/srt/model_executor/cuda_graph_runner.py
@@ -90,6 +90,8 @@ def set_torch_compile_config():

  # FIXME: tmp workaround
  torch._dynamo.config.accumulated_cache_size_limit = 1024
+ if hasattr(torch._dynamo.config, "cache_size_limit"):
+ torch._dynamo.config.cache_size_limit = 1024


  @maybe_torch_compile(dynamic=True)
@@ -111,6 +113,8 @@ class CudaGraphRunner:
  self.use_torch_compile = model_runner.server_args.enable_torch_compile
  self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
  self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+ self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+ self.tp_size = self.model_runner.tp_size

  # Batch sizes to capture
  if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +169,15 @@ class CudaGraphRunner:
  else:
  self.encoder_lens = None

+ if self.enable_dp_attention:
+ self.gathered_buffer = torch.zeros(
+ (
+ self.max_bs * self.tp_size,
+ self.model_runner.model_config.hidden_size,
+ ),
+ dtype=self.model_runner.dtype,
+ )
+
  # Capture
  try:
  with self.model_capture_mode():
@@ -190,11 +203,21 @@ class CudaGraphRunner:
  self.model_runner.model.capture_mode = False

  def can_run(self, forward_batch: ForwardBatch):
- is_bs_supported = (
- forward_batch.batch_size in self.graphs
- if self.disable_padding
- else forward_batch.batch_size <= self.max_bs
- )
+ if self.enable_dp_attention:
+ min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+ forward_batch.global_num_tokens
+ )
+ is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+ (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+ if self.disable_padding
+ else max_num_tokens <= self.max_bs
+ )
+ else:
+ is_bs_supported = (
+ forward_batch.batch_size in self.graphs
+ if self.disable_padding
+ else forward_batch.batch_size <= self.max_bs
+ )

  # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
  # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
@@ -239,6 +262,13 @@ class CudaGraphRunner:
  seq_lens_sum = seq_lens.sum().item()
  mrope_positions = self.mrope_positions[:, :bs]

+ if self.enable_dp_attention:
+ global_num_tokens = [bs] * self.tp_size
+ gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+ else:
+ global_num_tokens = None
+ gathered_buffer = None
+
  # Attention backend
  self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
  bs,
@@ -265,6 +295,8 @@ class CudaGraphRunner:
  top_logprobs_nums=[0] * bs,
  positions=clamp_position(seq_lens),
  mrope_positions=mrope_positions,
+ global_num_tokens=global_num_tokens,
+ gathered_buffer=gathered_buffer,
  )
  logits_output = forward(input_ids, forward_batch.positions, forward_batch)
  return logits_output.next_token_logits
@@ -295,7 +327,12 @@ class CudaGraphRunner:
  raw_bs = forward_batch.batch_size

  # Pad
- index = bisect.bisect_left(self.capture_bs, raw_bs)
+ if self.enable_dp_attention:
+ index = bisect.bisect_left(
+ self.capture_bs, max(forward_batch.global_num_tokens)
+ )
+ else:
+ index = bisect.bisect_left(self.capture_bs, raw_bs)
  bs = self.capture_bs[index]
  if bs != raw_bs:
  self.seq_lens.fill_(1)
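
Aside: the padding step above rounds the effective batch size up to the nearest captured graph size with bisect_left; with DP attention it pads to max(forward_batch.global_num_tokens) across ranks instead of the local batch size. A toy illustration with made-up capture sizes and token counts:

import bisect

capture_bs = [1, 2, 4, 8, 16, 24, 32]   # hypothetical captured CUDA-graph sizes
global_num_tokens = [7, 11, 9, 10]      # hypothetical per-rank token counts

raw_bs = max(global_num_tokens)         # DP attention pads to the largest rank
index = bisect.bisect_left(capture_bs, raw_bs)
bs = capture_bs[index]
print(raw_bs, "->", bs)                 # 11 -> 16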
sglang/srt/model_executor/forward_batch_info.py
@@ -36,6 +36,8 @@ from enum import IntEnum, auto
  from typing import TYPE_CHECKING, List, Optional

  import torch
+ import triton
+ import triton.language as tl

  from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

@@ -50,12 +52,18 @@ if TYPE_CHECKING:
  class ForwardMode(IntEnum):
  # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
  PREFILL = auto()
- # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+ # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
  EXTEND = auto()
  # Decode one token.
  DECODE = auto()
- # Contains both EXTEND and DECODE.
+ # Contains both EXTEND and DECODE when doing chunked prefill.
  MIXED = auto()
+ # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated.
+ IDLE = auto()
+
+ # A dummy first batch to start the pipeline for overlap scheduler.
+ # It is now used for triggering the sampling_info_done event for the first prefill batch.
+ DUMMY_FIRST = auto()

  def is_prefill(self):
  return self == ForwardMode.PREFILL
@@ -69,6 +77,12 @@ class ForwardMode(IntEnum):
  def is_mixed(self):
  return self == ForwardMode.MIXED

+ def is_idle(self):
+ return self == ForwardMode.IDLE
+
+ def is_dummy_first(self):
+ return self == ForwardMode.DUMMY_FIRST
+

  @dataclass
  class ForwardBatch:
@@ -102,6 +116,7 @@ class ForwardBatch:
  extend_seq_lens: Optional[torch.Tensor] = None
  extend_prefix_lens: Optional[torch.Tensor] = None
  extend_start_loc: Optional[torch.Tensor] = None
+ extend_prefix_lens_cpu: Optional[List[int]] = None
  extend_seq_lens_cpu: Optional[List[int]] = None
  extend_logprob_start_lens_cpu: Optional[List[int]] = None

@@ -128,6 +143,11 @@ class ForwardBatch:
  # For Qwen2-VL
  mrope_positions: torch.Tensor = None

+ # For DP attention
+ global_num_tokens: Optional[List[int]] = None
+ gathered_buffer: Optional[torch.Tensor] = None
+ can_run_dp_cuda_graph: bool = False
+
  def compute_mrope_positions(
  self, model_runner: ModelRunner, batch: ModelWorkerBatch
  ):
@@ -209,31 +229,36 @@ class ForwardBatch:
  seq_lens_sum=batch.seq_lens_sum,
  return_logprob=batch.return_logprob,
  top_logprobs_nums=batch.top_logprobs_nums,
+ global_num_tokens=batch.global_num_tokens,
+ can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
  lora_paths=batch.lora_paths,
  sampling_info=batch.sampling_info,
  )

+ if ret.global_num_tokens is not None:
+ max_len = max(ret.global_num_tokens)
+ ret.gathered_buffer = torch.zeros(
+ (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+ dtype=model_runner.dtype,
+ device=device,
+ )
+
+ if ret.forward_mode.is_idle():
+ return ret
+
  # Init position information
  if not ret.forward_mode.is_decode():
- ret.positions = torch.concat(
- [
- torch.arange(prefix_len, prefix_len + extend_len, device=device)
- for prefix_len, extend_len in zip(
- batch.extend_prefix_lens, batch.extend_seq_lens
- )
- ],
- axis=0,
- )
- ret.extend_num_tokens = batch.extend_num_tokens
  ret.extend_seq_lens = torch.tensor(
  batch.extend_seq_lens, dtype=torch.int32
  ).to(device, non_blocking=True)
-
  ret.extend_prefix_lens = torch.tensor(
  batch.extend_prefix_lens, dtype=torch.int32
  ).to(device, non_blocking=True)
- ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
- ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+ ret.extend_num_tokens = batch.extend_num_tokens
+ ret.positions, ret.extend_start_loc = compute_position_triton(
+ ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+ )
+ ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
  ret.extend_seq_lens_cpu = batch.extend_seq_lens
  ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

@@ -250,3 +275,72 @@ class ForwardBatch:
  model_runner.lora_manager.prepare_lora_batch(ret)

  return ret
+
+
+ def compute_position_triton(
+ extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+ ):
+ """Compute positions. It is a fused version of `compute_position_torch`."""
+ batch_size = extend_seq_lens.shape[0]
+ positions = torch.empty(
+ extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+ )
+ extend_start_loc = torch.empty(
+ batch_size, dtype=torch.int32, device=extend_seq_lens.device
+ )
+
+ # Launch kernel
+ compute_position_kernel[(batch_size,)](
+ positions,
+ extend_start_loc,
+ extend_prefix_lens,
+ extend_seq_lens,
+ )
+
+ return positions, extend_start_loc
+
+
+ @triton.jit
+ def compute_position_kernel(
+ positions,
+ extend_start_loc,
+ extend_prefix_lens,
+ extend_seq_lens,
+ ):
+ BLOCK_SIZE: tl.constexpr = 512
+ pid = tl.program_id(0)
+
+ prefix_len = tl.load(extend_prefix_lens + pid)
+ seq_len = tl.load(extend_seq_lens + pid)
+
+ # TODO: optimize this?
+ cumsum_start = 0
+ for i in range(pid):
+ cumsum_start += tl.load(extend_seq_lens + i)
+
+ num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+ for i in range(num_loop):
+ offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+ tl.store(
+ positions + cumsum_start + offset,
+ prefix_len + offset,
+ mask=offset < seq_len,
+ )
+ tl.store(extend_start_loc + pid, cumsum_start)
+
+
+ def compute_position_torch(
+ extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+ ):
+ positions = torch.concat(
+ [
+ torch.arange(
+ prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+ )
+ for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+ ],
+ axis=0,
+ )
+ extend_start_loc = torch.zeros_like(extend_seq_lens)
+ extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+ return positions.to(torch.int64), extend_start_loc
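
Aside: for a concrete feel of what the new position computation returns, here is the compute_position_torch fallback shown above, inlined on a tiny made-up batch so the snippet is standalone (the Triton kernel is documented as a fused version of the same computation):

import torch

extend_prefix_lens = [2, 0]  # tokens already in the KV cache per request (made up)
extend_seq_lens = [3, 2]     # new tokens to extend per request (made up)

positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ],
    axis=0,
)
extend_start_loc = torch.zeros(len(extend_seq_lens), dtype=torch.int32)
extend_start_loc[1:] = torch.cumsum(torch.tensor(extend_seq_lens[:-1]), dim=0)

print(positions.tolist())         # [2, 3, 4, 0, 1]
print(extend_start_loc.tolist())  # [0, 3]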