sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """DetokenizerManager is a process that detokenizes the token ids."""

  import dataclasses
sglang/srt/managers/io_struct.py
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """
  The definition of objects transfered between different
  processes (TokenizerManager, DetokenizerManager, Controller).
@@ -21,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
  import uuid
  from dataclasses import dataclass
  from enum import Enum
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Optional, Tuple, Union

  from sglang.srt.managers.schedule_batch import BaseFinishReason
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -31,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
  class GenerateReqInput:
      # The input prompt. It can be a single prompt or a batch of prompts.
      text: Optional[Union[List[str], str]] = None
-     # The token ids for text; one can either specify text or input_ids.
+     # The token ids for text; one can specify either text or input_ids
      input_ids: Optional[Union[List[List[int]], List[int]]] = None
+     # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
+     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
      # The image input. It can be a file name, a url, or base64 encoded string.
      # See also python/sglang/srt/utils.py:load_image.
      image_data: Optional[Union[List[str], str]] = None
@@ -56,11 +56,22 @@ class GenerateReqInput:
      # LoRA related
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

+     # Session id info for continual prompting
+     session: Optional[
+         Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
+     ] = None
+
      def normalize_batch_and_arguments(self):
-         if (self.text is None and self.input_ids is None) or (
-             self.text is not None and self.input_ids is not None
+         if (
+             self.text is None and self.input_ids is None and self.input_embeds is None
+         ) or (
+             self.text is not None
+             and self.input_ids is not None
+             and self.input_embeds is not None
          ):
-             raise ValueError("Either text or input_ids should be provided.")
+             raise ValueError(
+                 "Either text, input_ids or input_embeds should be provided."
+             )

          # Derive the batch size
          if self.text is not None:
@@ -70,13 +81,21 @@ class GenerateReqInput:
              else:
                  self.is_single = False
                  self.batch_size = len(self.text)
-         else:
+             self.input_embeds = None
+         elif self.input_ids is not None:
              if isinstance(self.input_ids[0], int):
                  self.is_single = True
                  self.batch_size = 1
              else:
                  self.is_single = False
                  self.batch_size = len(self.input_ids)
+             self.input_embeds = None
+         else:
+             if isinstance(self.input_embeds[0][0], float):
+                 self.is_single = True
+                 self.batch_size = 1
+             else:
+                 self.batch_size = len(self.input_embeds)

          # Handle parallel sampling
          # When parallel sampling is used, we always treat the input as a batch.
@@ -199,6 +218,12 @@ class TokenizedGenerateReqInput:

      # LoRA related
      lora_path: Optional[str] = None  # None means just use the base model
+     # The input embeds
+     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+
+     # Session id info for continual prompting
+     session_id: Optional[str] = None
+     session_rid: Optional[str] = None


  @dataclass
@@ -211,6 +236,8 @@ class EmbeddingReqInput:
      rid: Optional[Union[List[str], str]] = None
      # Dummy sampling params for compatibility
      sampling_params: Union[List[Dict], Dict] = None
+     # Dummy input embeds for compatibility
+     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

      def normalize_batch_and_arguments(self):
          if (self.text is None and self.input_ids is None) or (
@@ -357,3 +384,18 @@ class GetMemPoolSizeReq:
  @dataclass
  class GetMemPoolSizeReqOutput:
      size: int
+
+
+ @dataclass
+ class OpenSessionReqInput:
+     capacity_of_str_len: int
+
+
+ @dataclass
+ class CloseSessionReqInput:
+     session_id: str
+
+
+ @dataclass
+ class OpenSessionReqOutput:
+     session_id: str
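As a quick illustration of the new request fields, here is a hedged Python sketch (not part of the diff) of building a GenerateReqInput with input_embeds and a session tuple; the field names come from the dataclass above, while the concrete values are invented for illustration.

# Illustrative sketch only -- field names mirror the dataclass above, values are invented.
from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(
    # Exactly one of text / input_ids / input_embeds should be provided.
    input_embeds=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # one prompt given as two embedding vectors
    session=("my-session-id", None),  # (session id, optional request id), per the Tuple type above
)
# normalize_batch_and_arguments() would take the final `else` branch shown above:
# input_embeds[0][0] is a float, so is_single=True and batch_size=1.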
sglang/srt/managers/schedule_batch.py
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """
  Store information about requests and batches.

@@ -34,6 +32,8 @@ import logging
  from typing import List, Optional, Tuple, Union

  import torch
+ import triton
+ import triton.language as tl

  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
@@ -55,7 +55,8 @@ global_server_args_dict = {
      "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
      "disable_mla": ServerArgs.disable_mla,
      "torchao_config": ServerArgs.torchao_config,
-     "disable_nan_detection": ServerArgs.disable_nan_detection,
+     "enable_nan_detection": ServerArgs.enable_nan_detection,
+     "enable_dp_attention": ServerArgs.enable_dp_attention,
  }


@@ -133,6 +134,7 @@ class ImageInputs:
      image_embeds: Optional[List[torch.Tensor]] = None
      aspect_ratio_ids: Optional[List[torch.Tensor]] = None
      aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
      # QWen2-VL related
      image_grid_thws: List[Tuple[int, int, int]] = None
      mrope_position_delta: Optional[torch.Tensor] = None
@@ -176,6 +178,8 @@ class Req:
          origin_input_ids: Tuple[int],
          sampling_params: SamplingParams,
          lora_path: Optional[str] = None,
+         input_embeds: Optional[List[List[float]]] = None,
+         session_id: Optional[str] = None,
      ):
          # Input and output info
          self.rid = rid
@@ -184,11 +188,13 @@ class Req:
          self.origin_input_ids = origin_input_ids
          self.output_ids = []  # Each decode stage's output ids
          self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+         self.session_id = session_id

          self.sampling_params = sampling_params
          self.lora_path = lora_path
+         self.input_embeds = input_embeds

-         # Memory info
+         # Memory pool info
          self.req_pool_idx = None

          # Check finish
@@ -425,7 +431,7 @@ bid = 0

  @dataclasses.dataclass
  class ScheduleBatch:
-     """Store all inforamtion of a batch."""
+     """Store all inforamtion of a batch on the scheduler."""

      # Request, memory pool, and cache
      reqs: List[Req]
@@ -433,14 +439,18 @@ class ScheduleBatch:
      token_to_kv_pool: BaseTokenToKVPool = None
      tree_cache: BasePrefixCache = None

-     # For utility
+     # Batch configs
      model_config: ModelConfig = None
-
      forward_mode: ForwardMode = None
+     enable_overlap: bool = False
+
+     # Sampling info
      sampling_info: SamplingBatchInfo = None
+     next_batch_sampling_info: SamplingBatchInfo = None

      # Batched arguments to model runner
      input_ids: torch.Tensor = None
+     input_embeds: torch.Tensor = None
      req_pool_indices: torch.Tensor = None
      seq_lens: torch.Tensor = None
      # The output locations of the KV cache
@@ -450,6 +460,10 @@ class ScheduleBatch:
      # The sum of all sequence lengths
      seq_lens_sum: int = None

+     # For DP attention
+     global_num_tokens: Optional[List[int]] = None
+     can_run_dp_cuda_graph: bool = False
+
      # For processing logprobs
      return_logprob: bool = False
      top_logprobs_nums: Optional[List[int]] = None
@@ -459,6 +473,7 @@ class ScheduleBatch:
      extend_lens: List[int] = None
      extend_num_tokens: int = None
      decoding_reqs: List[Req] = None
+     extend_logprob_start_lens: List[int] = None

      # For encoder-decoder
      encoder_cached: Optional[List[bool]] = None
@@ -479,10 +494,11 @@ class ScheduleBatch:
      def init_new(
          cls,
          reqs: List[Req],
-         req_to_token_pool,
-         token_to_kv_pool,
-         tree_cache,
-         model_config,
+         req_to_token_pool: ReqToTokenPool,
+         token_to_kv_pool: ReqToTokenPool,
+         tree_cache: BasePrefixCache,
+         model_config: ModelConfig,
+         enable_overlap: bool,
      ):
          return cls(
              reqs=reqs,
@@ -490,6 +506,7 @@ class ScheduleBatch:
              token_to_kv_pool=token_to_kv_pool,
              tree_cache=tree_cache,
              model_config=model_config,
+             enable_overlap=enable_overlap,
              return_logprob=any(req.return_logprob for req in reqs),
              has_stream=any(req.stream for req in reqs),
              has_grammar=any(req.grammar for req in reqs),
@@ -502,7 +519,7 @@ class ScheduleBatch:
      def is_empty(self):
          return len(self.reqs) == 0

-     def alloc_req_slots(self, num_reqs):
+     def alloc_req_slots(self, num_reqs: int):
          req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
          if req_pool_indices is None:
              raise RuntimeError(
@@ -588,14 +605,14 @@ class ScheduleBatch:
          )

          if not decoder_out_cache_loc:
-             self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+             self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                  self.device, non_blocking=True
              )
          else:
              self.out_cache_loc = torch.cat(decoder_out_cache_loc)

          if not encoder_out_cache_loc:
-             self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+             self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                  self.device, non_blocking=True
              )
          else:
@@ -611,11 +628,14 @@ class ScheduleBatch:
          input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
          extend_num_tokens = sum(len(ids) for ids in input_ids)
          seq_lens = []
+         pre_lens = []

          # Allocate memory
          req_pool_indices = self.alloc_req_slots(bs)
          out_cache_loc = self.alloc_token_slots(extend_num_tokens)

+         input_embeds = []
+
          pt = 0
          for i, req in enumerate(reqs):
              already_computed = (
@@ -634,10 +654,11 @@ class ScheduleBatch:
              self.req_to_token_pool.write(
                  (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
              )
-             self.req_to_token_pool.write(
-                 (req.req_pool_idx, slice(pre_len, seq_len)),
-                 out_cache_loc[pt : pt + req.extend_input_len],
-             )
+
+             # If input_embeds are available, store them
+             if req.input_embeds is not None:
+                 # If req.input_embeds is already a list, append its content directly
+                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

              # Compute the relative logprob_start_len in an extend batch
              if req.logprob_start_len >= pre_len:
@@ -648,8 +669,8 @@ class ScheduleBatch:
                  extend_logprob_start_len = req.extend_input_len - 1

              req.extend_logprob_start_len = extend_logprob_start_len
-             pt += req.extend_input_len
              req.is_retracted = False
+             pre_lens.append(pre_len)

          # Set fields
          self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -661,6 +682,11 @@ class ScheduleBatch:
          self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
              self.device, non_blocking=True
          )
+         self.input_embeds = (
+             torch.tensor(input_embeds).to(self.device, non_blocking=True)
+             if input_embeds
+             else None
+         )

          self.out_cache_loc = out_cache_loc

@@ -672,13 +698,37 @@ class ScheduleBatch:
          self.extend_lens = [r.extend_input_len for r in reqs]
          self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+         # Write to req_to_token_pool
+         pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
+             self.device, non_blocking=True
+         )
+         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
+             self.device, non_blocking=True
+         )
+         write_req_to_token_pool_triton[(bs,)](
+             self.req_to_token_pool.req_to_token,
+             self.req_pool_indices,
+             pre_lens,
+             self.seq_lens,
+             extend_lens,
+             self.out_cache_loc,
+             self.req_to_token_pool.req_to_token.shape[1],
+         )
+         # The triton kernel is equivalent to the following python code.
+         # self.req_to_token_pool.write(
+         #     (req.req_pool_idx, slice(pre_len, seq_len)),
+         #     out_cache_loc[pt : pt + req.extend_input_len],
+         # )
+         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+
          if self.model_config.is_encoder_decoder:
              self.prepare_encoder_info_extend(input_ids, seq_lens)

+         # Build sampling info
          self.sampling_info = SamplingBatchInfo.from_schedule_batch(
              self,
              self.model_config.vocab_size,
-             global_server_args_dict["disable_penalizer"],
+             enable_overlap_schedule=self.enable_overlap,
          )

      def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -695,16 +745,20 @@ class ScheduleBatch:
          self.merge_batch(running_batch)
          self.input_ids = input_ids
          self.out_cache_loc = out_cache_loc
-         self.extend_num_tokens += running_bs
+
+         # For overlap scheduler, the output_ids has one step delay
+         delta = 0 if self.enable_overlap else -1

          # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
          self.prefix_lens.extend(
              [
-                 len(r.origin_input_ids) + len(r.output_ids) - 1
+                 len(r.origin_input_ids) + len(r.output_ids) + delta
                  for r in running_batch.reqs
              ]
          )
          self.extend_lens.extend([1] * running_bs)
+         self.extend_num_tokens += running_bs
+         # TODO (lianmin): Revisit this. It should be seq_len - 1
          self.extend_logprob_start_lens.extend([0] * running_bs)

      def check_decode_mem(self):
@@ -720,6 +774,7 @@ class ScheduleBatch:
          return False

      def retract_decode(self):
+         """Retract the decoding requests when there is not enough memory."""
          sorted_indices = [i for i in range(len(self.reqs))]

          # TODO(lsyin): improve retraction policy for radix cache
@@ -858,15 +913,21 @@ class ScheduleBatch:
          # Reset the encoder cached status
          self.encoder_cached = [True] * len(self.reqs)

-     def prepare_for_decode(self, enable_overlap: bool = False):
+     def prepare_for_idle(self):
+         self.forward_mode = ForwardMode.IDLE
+         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+         self.seq_lens_sum = 0
+         self.extend_num_tokens = 0
+
+     def prepare_for_decode(self):
          self.forward_mode = ForwardMode.DECODE

          self.input_ids = self.output_ids
          self.output_ids = None
-         if self.sampling_info.penalizer_orchestrator:
-             self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                 self.input_ids
-             )
+         self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

          # Alloc mem
          bs = len(self.reqs)
@@ -878,7 +939,7 @@ class ScheduleBatch:
          else:
              locs = self.seq_lens

-         if enable_overlap:
+         if self.enable_overlap:
              # Do not use in-place operations in the overlap mode
              self.req_to_token_pool.write(
                  (self.req_pool_indices, locs), self.out_cache_loc
@@ -969,17 +1030,18 @@ class ScheduleBatch:
          self.has_grammar = self.has_grammar or other.has_grammar

      def get_model_worker_batch(self):
-         if self.forward_mode.is_decode():
+         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
              extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
          else:
              extend_seq_lens = self.extend_lens
              extend_prefix_lens = self.prefix_lens
              extend_logprob_start_lens = self.extend_logprob_start_lens

-         if self.has_grammar:
-             self.sampling_info.grammars = [req.grammar for req in self.reqs]
-         else:
-             self.sampling_info.grammars = None
+         if self.sampling_info:
+             if self.has_grammar:
+                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
+             else:
+                 self.sampling_info.grammars = None

          global bid
          bid += 1
@@ -995,6 +1057,8 @@ class ScheduleBatch:
              req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
              return_logprob=self.return_logprob,
              top_logprobs_nums=self.top_logprobs_nums,
+             global_num_tokens=self.global_num_tokens,
+             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
              extend_num_tokens=self.extend_num_tokens,
              extend_seq_lens=extend_seq_lens,
              extend_prefix_lens=extend_prefix_lens,
@@ -1006,6 +1070,7 @@ class ScheduleBatch:
              encoder_out_cache_loc=self.encoder_out_cache_loc,
              lora_paths=[req.lora_path for req in self.reqs],
              sampling_info=self.sampling_info,
+             input_embeds=self.input_embeds,
          )

      def copy(self):
@@ -1051,6 +1116,10 @@ class ModelWorkerBatch:
      return_logprob: bool
      top_logprobs_nums: Optional[List[int]]

+     # For DP attention
+     global_num_tokens: Optional[List[int]]
+     can_run_dp_cuda_graph: bool
+
      # For extend
      extend_num_tokens: Optional[int]
      extend_seq_lens: Optional[List[int]]
@@ -1072,16 +1141,42 @@ class ModelWorkerBatch:
      # Sampling info
      sampling_info: SamplingBatchInfo

-     def copy(self):
-         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
-
-     def to(self, device: str):
-         self.input_ids = self.input_ids.to(device, non_blocking=True)
-         self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
-         self.seq_lens = self.seq_lens.to(device, non_blocking=True)
-         self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
-         self.req_to_token_pool_records = [
-             (x, y.to(device, non_blocking=True))
-             for x, y in self.req_to_token_pool_records
-         ]
-         self.sampling_info.to(device)
+     # The input Embeds
+     input_embeds: Optional[torch.tensor] = None
+
+
+ @triton.jit
+ def write_req_to_token_pool_triton(
+     req_to_token_ptr,  # [max_batch, max_context_len]
+     req_pool_indices,
+     pre_lens,
+     seq_lens,
+     extend_lens,
+     out_cache_loc,
+     req_to_token_ptr_stride: tl.constexpr,
+ ):
+     BLOCK_SIZE: tl.constexpr = 512
+     pid = tl.program_id(0)
+
+     req_pool_index = tl.load(req_pool_indices + pid)
+     pre_len = tl.load(pre_lens + pid)
+     seq_len = tl.load(seq_lens + pid)
+
+     # TODO: optimize this?
+     cumsum_start = 0
+     for i in range(pid):
+         cumsum_start += tl.load(extend_lens + i)
+
+     num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
+     for i in range(num_loop):
+         offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+         mask = offset < (seq_len - pre_len)
+         value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
+         tl.store(
+             req_to_token_ptr
+             + req_pool_index * req_to_token_ptr_stride
+             + offset
+             + pre_len,
+             value,
+             mask=mask,
+         )
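For readers less familiar with Triton, the kernel added above writes each request's newly allocated KV-cache slots into that request's row of the request-to-token map. The following pure-PyTorch restatement is a hedged sketch, not part of the package: the helper name and the explicit loop are mine, and it only mirrors the kernel's semantics (cumsum_start is the running sum of previous requests' extend_lens).

# Hypothetical reference loop, equivalent in effect to write_req_to_token_pool_triton.
import torch

def write_req_to_token_pool_reference(req_to_token, req_pool_indices, pre_lens,
                                       seq_lens, extend_lens, out_cache_loc):
    pt = 0  # running offset into out_cache_loc, i.e. the kernel's cumsum_start
    for i in range(len(req_pool_indices)):
        idx = int(req_pool_indices[i])
        pre_len, seq_len = int(pre_lens[i]), int(seq_lens[i])
        ext = int(extend_lens[i])  # equals seq_len - pre_len
        # Copy this request's new KV slots into its row, starting at pre_len.
        req_to_token[idx, pre_len:seq_len] = out_cache_loc[pt : pt + ext]
        pt += ext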
sglang/srt/managers/schedule_policy.py
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """Request scheduler policy"""

  import os
@@ -302,7 +300,11 @@ class PrefillAdder:
          if (
              self.rem_chunk_tokens is None
              or input_tokens <= self.rem_chunk_tokens
-             or (req.return_logprob and req.normalized_prompt_logprob is None)
+             or (
+                 req.return_logprob
+                 and req.normalized_prompt_logprob is None
+                 and req.logprob_start_len != len(req.origin_input_ids) - 1
+             )
          ):
              # Non-chunked prefill
              self.can_run_list.append(req)
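The schedule_policy.py change narrows when a logprob request may bypass chunked prefill. Restated as a hedged standalone predicate (the helper function is hypothetical; the attribute names come from the diff):

def allows_non_chunked_prefill(rem_chunk_tokens, input_tokens, req):
    # Hypothetical restatement of the updated condition in PrefillAdder.
    return (
        rem_chunk_tokens is None
        or input_tokens <= rem_chunk_tokens
        or (
            req.return_logprob
            and req.normalized_prompt_logprob is None
            # New check in 0.3.6: a request whose logprobs start at the last
            # prompt token no longer forces a non-chunked prefill.
            and req.logprob_start_len != len(req.origin_input_ids) - 1
        )
    )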