sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
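The listing above summarizes the changes between the two released wheels. For readers who want to reproduce a per-file diff locally, the following is a minimal sketch (the local wheel filenames are assumptions; both wheels would first need to be downloaded, e.g. with `pip download sglang==<version> --no-deps`):

import difflib
import zipfile

# Hypothetical local filenames for the two released wheels.
OLD = "sglang-0.3.1.post3-py3-none-any.whl"
NEW = "sglang-0.3.3-py3-none-any.whl"

def read_member(wheel_path, member_name):
    # A wheel is a zip archive, so its members can be read directly.
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member_name).decode("utf-8", errors="replace").splitlines(keepends=True)

# Example: reproduce the hunk shown below for the new scheduler module.
name = "sglang/srt/managers/scheduler.py"
old_text = []                      # the file does not exist in the old wheel
new_text = read_member(NEW, name)
print("".join(difflib.unified_diff(old_text, new_text, f"old/{name}", f"new/{name}")))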
sglang/srt/managers/scheduler.py (new file)
@@ -0,0 +1,1021 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """A scheduler that manages a tensor parallel GPU worker."""
+
+ import json
+ import logging
+ import multiprocessing
+ import os
+ import time
+ import warnings
+ from typing import List, Optional, Union
+
+ import torch
+ import zmq
+
+ from sglang.global_config import global_config
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.constrained.fsm_cache import FSMCache
+ from sglang.srt.constrained.jump_forward import JumpForwardCache
+ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.managers.io_struct import (
+     AbortReq,
+     BatchEmbeddingOut,
+     BatchTokenIDOut,
+     FlushCacheReq,
+     TokenizedEmbeddingReqInput,
+     TokenizedGenerateReqInput,
+     TokenizedRewardReqInput,
+     UpdateWeightReqInput,
+     UpdateWeightReqOutput,
+ )
+ from sglang.srt.managers.schedule_batch import (
+     FINISH_ABORT,
+     BaseFinishReason,
+     ImageInputs,
+     Req,
+     ScheduleBatch,
+ )
+ from sglang.srt.managers.schedule_policy import (
+     AddReqResult,
+     PrefillAdder,
+     SchedulePolicy,
+ )
+ from sglang.srt.managers.tp_worker import TpModelWorker
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache
+ from sglang.srt.mem_cache.radix_cache import RadixCache
+ from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.utils import (
+     broadcast_pyobj,
+     configure_logger,
+     is_generation_model,
+     is_multimodal_model,
+     kill_parent_process,
+     pytorch_profile,
+     set_random_seed,
+     suppress_other_loggers,
+ )
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)
+
+ # Crash on warning if we are running CI tests
+ crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
+ class Scheduler:
+     """A scheduler that manages a tensor parallel GPU worker."""
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+         gpu_id: int,
+         tp_rank: int,
+     ):
+         # Parse args
+         self.server_args = server_args
+         self.tp_rank = tp_rank
+         self.tp_size = server_args.tp_size
+         self.schedule_policy = server_args.schedule_policy
+         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+         self.lora_paths = server_args.lora_paths
+         self.max_loras_per_batch = server_args.max_loras_per_batch
+
+         # Init inter-process communication
+         context = zmq.Context(2)
+
+         if self.tp_rank == 0:
+             self.recv_from_tokenizer = context.socket(zmq.PULL)
+             self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+
+             self.send_to_detokenizer = context.socket(zmq.PUSH)
+             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
+         else:
+             self.recv_from_tokenizer = self.send_to_detokenizer = None
+
+         # Init tokenizer
+         self.model_config = ModelConfig(
+             server_args.model_path,
+             server_args.trust_remote_code,
+             context_length=server_args.context_length,
+             model_override_args=json.loads(server_args.json_model_override_args),
+         )
+
+         if server_args.skip_tokenizer_init:
+             self.tokenizer = self.processor = None
+         else:
+             if is_multimodal_model(self.model_config.hf_config.architectures):
+                 self.processor = get_processor(
+                     server_args.tokenizer_path,
+                     tokenizer_mode=server_args.tokenizer_mode,
+                     trust_remote_code=server_args.trust_remote_code,
+                 )
+                 self.tokenizer = self.processor.tokenizer
+             else:
+                 self.tokenizer = get_tokenizer(
+                     server_args.tokenizer_path,
+                     tokenizer_mode=server_args.tokenizer_mode,
+                     trust_remote_code=server_args.trust_remote_code,
+                 )
+         self.is_generation = is_generation_model(
+             self.model_config.hf_config.architectures, self.server_args.is_embedding
+         )
+
+         # Launch a tensor parallel worker
+         self.tp_worker = TpModelWorker(
+             gpu_id=gpu_id,
+             tp_rank=tp_rank,
+             server_args=server_args,
+             nccl_port=port_args.nccl_ports[0],
+         )
+         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+
+         # Get token and memory info from the model worker
+         (
+             self.max_total_num_tokens,
+             self.max_prefill_tokens,
+             self.max_running_requests,
+             self.max_req_input_len,
+             self.random_seed,
+         ) = self.tp_worker.get_token_and_memory_info()
+         set_random_seed(self.random_seed)
+         self.pad_input_ids_func = getattr(
+             self.tp_worker.model_runner.model, "pad_input_ids", None
+         )
+
+         # Print debug info
+         logger.info(
+             f"max_total_num_tokens={self.max_total_num_tokens}, "
+             f"max_prefill_tokens={self.max_prefill_tokens}, "
+             f"max_running_requests={self.max_running_requests}, "
+             f"context_len={self.model_config.context_len}"
+         )
+
+         # Init cache
+         self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
+         self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+
+         if (
+             server_args.chunked_prefill_size is not None
+             and server_args.disable_radix_cache
+         ):
+             self.tree_cache = ChunkCache(
+                 req_to_token_pool=self.req_to_token_pool,
+                 token_to_kv_pool=self.token_to_kv_pool,
+             )
+         else:
+             self.tree_cache = RadixCache(
+                 req_to_token_pool=self.req_to_token_pool,
+                 token_to_kv_pool=self.token_to_kv_pool,
+                 disable=server_args.disable_radix_cache,
+             )
+         self.tree_cache_metrics = {"total": 0, "hit": 0}
+         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+
+         # Init running status
+         self.waiting_queue: List[Req] = []
+         self.running_batch: ScheduleBatch = None
+         self.out_pyobjs = []
+         self.decode_forward_ct = 0
+         self.stream_interval = server_args.stream_interval
+         self.num_generated_tokens = 0
+         self.last_stats_tic = time.time()
+
+         # Init chunked prefill
+         self.chunked_prefill_size = server_args.chunked_prefill_size
+         self.current_inflight_req = None
+         self.is_mixed_chunk = (
+             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+         )
+
+         # Init the FSM cache for constrained generation
+         if not server_args.skip_tokenizer_init:
+             self.regex_fsm_cache = FSMCache(
+                 server_args.tokenizer_path,
+                 {
+                     "tokenizer_mode": server_args.tokenizer_mode,
+                     "trust_remote_code": server_args.trust_remote_code,
+                 },
+                 skip_tokenizer_init=server_args.skip_tokenizer_init,
+                 constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+             )
+         self.jump_forward_cache = JumpForwardCache()
+
+         # Init new token estimation
+         assert (
+             server_args.schedule_conservativeness >= 0
+         ), "Invalid schedule_conservativeness"
+         self.min_new_token_ratio = min(
+             global_config.base_min_new_token_ratio
+             * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.new_token_ratio = self.min_new_token_ratio
+         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+         self.batch_is_full = False
+
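For orientation: the constructor above wires the scheduler between the tokenizer and detokenizer processes over ZeroMQ IPC sockets, and only TP rank 0 owns the sockets. A minimal standalone sketch of the same PULL/PUSH pattern follows (the endpoint paths here are illustrative, not the names that come from PortArgs):

import zmq

# Illustrative endpoint names; the real ones come from PortArgs in the diff above.
SCHEDULER_INPUT_IPC = "ipc:///tmp/demo_scheduler_input"
DETOKENIZER_IPC = "ipc:///tmp/demo_detokenizer"

context = zmq.Context(2)

# The scheduler pulls tokenized requests pushed by the tokenizer process...
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind(SCHEDULER_INPUT_IPC)

# ...and pushes detokenization work to the detokenizer process.
send_to_detokenizer = context.socket(zmq.PUSH)
send_to_detokenizer.connect(DETOKENIZER_IPC)

# Non-blocking drain of pending requests, mirroring Scheduler.recv_requests below.
reqs = []
while True:
    try:
        reqs.append(recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK))
    except zmq.ZMQError:
        break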
+     @torch.inference_mode()
+     def event_loop(self):
+         while True:
+             recv_reqs = self.recv_requests()
+             self.process_input_requests(recv_reqs)
+
+             self.run_step()
+
+             self.send_results()
+
+     def recv_requests(self):
+         if self.tp_rank == 0:
+             recv_reqs = []
+
+             while True:
+                 try:
+                     recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                 except zmq.ZMQError:
+                     break
+                 recv_reqs.append(recv_req)
+         else:
+             recv_reqs = None
+
+         if self.tp_size != 1:
+             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+         return recv_reqs
+
+     def process_input_requests(self, recv_reqs: List):
+         for recv_req in recv_reqs:
+             if isinstance(recv_req, TokenizedGenerateReqInput):
+                 self.handle_generate_request(recv_req)
+             elif isinstance(
+                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
+             ):
+                 self.handle_embedding_request(recv_req)
+             elif isinstance(recv_req, FlushCacheReq):
+                 self.flush_cache()
+             elif isinstance(recv_req, AbortReq):
+                 self.abort_request(recv_req)
+             elif isinstance(recv_req, UpdateWeightReqInput):
+                 success, message = self.update_weights(recv_req)
+                 self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+             else:
+                 raise ValueError(f"Invalid request: {recv_req}")
+
+     def handle_generate_request(
+         self,
+         recv_req: TokenizedGenerateReqInput,
+     ):
+         req = Req(
+             recv_req.rid,
+             recv_req.input_text,
+             recv_req.input_ids,
+             recv_req.sampling_params,
+             lora_path=recv_req.lora_path,
+         )
+         req.tokenizer = self.tokenizer
+
+         # Image inputs
+         if recv_req.image_inputs is not None:
+             req.image_inputs = ImageInputs.from_dict(
+                 recv_req.image_inputs, self.model_config.vocab_size
+             )
+             req.origin_input_ids = self.pad_input_ids_func(
+                 req.origin_input_ids_unpadded, req.image_inputs
+             )
+
+         req.return_logprob = recv_req.return_logprob
+         req.top_logprobs_num = recv_req.top_logprobs_num
+         req.stream = recv_req.stream
+         req.logprob_start_len = recv_req.logprob_start_len
+
+         if req.logprob_start_len == -1:
+             # By default, only return the logprobs for output tokens
+             req.logprob_start_len = len(recv_req.input_ids) - 1
+
+         # Init regex FSM
+         if (
+             req.sampling_params.json_schema is not None
+             or req.sampling_params.regex is not None
+         ):
+             if req.sampling_params.json_schema is not None:
+                 req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                     ("json", req.sampling_params.json_schema)
+                 )
+             elif req.sampling_params.regex is not None:
+                 req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                     ("regex", req.sampling_params.regex)
+                 )
+             if not self.disable_regex_jump_forward:
+                 req.jump_forward_map = self.jump_forward_cache.query(
+                     computed_regex_string
+                 )
+
+         # Truncate prompts that are too long
+         if len(req.origin_input_ids) >= self.max_req_input_len:
+             logger.warning(
+                 "Request length is longer than the KV cache pool size or "
+                 "the max context length. Truncated!!!"
+             )
+             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+         req.sampling_params.max_new_tokens = min(
+             (
+                 req.sampling_params.max_new_tokens
+                 if req.sampling_params.max_new_tokens is not None
+                 else 1 << 30
+             ),
+             self.max_req_input_len - 1 - len(req.origin_input_ids),
+         )
+
+         self.waiting_queue.append(req)
+
+     def handle_embedding_request(
+         self,
+         recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
+     ):
+         req = Req(
+             recv_req.rid,
+             recv_req.input_text,
+             recv_req.input_ids,
+             recv_req.sampling_params,
+         )
+         req.tokenizer = self.tokenizer
+
+         # Truncate prompts that are too long
+         if len(req.origin_input_ids) >= self.max_req_input_len:
+             logger.warning(
+                 "Request length is longer than the KV cache pool size or "
+                 "the max context length. Truncated!!!"
+             )
+             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
+         self.waiting_queue.append(req)
+
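To make the clamping arithmetic in handle_generate_request concrete, here is a small worked example; the numbers are invented and only illustrate the min() expression above:

# Hypothetical values, not taken from any real deployment.
max_req_input_len = 4096          # reported by the TP worker
prompt_len = 3000                 # prompt length after any truncation
requested = 2048                  # user-supplied max_new_tokens (None means "no explicit limit")

max_new_tokens = min(
    requested if requested is not None else 1 << 30,
    max_req_input_len - 1 - prompt_len,
)
assert max_new_tokens == 1095     # 4096 - 1 - 3000, tighter than the requested 2048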
+     def send_results(self):
+         if self.tp_rank == 0:
+             for obj in self.out_pyobjs:
+                 self.send_to_detokenizer.send_pyobj(obj)
+             self.out_pyobjs = []
+
+     def print_decode_stats(self):
+         num_used = self.max_total_num_tokens - (
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+         )
+         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+         self.num_generated_tokens = 0
+         self.last_stats_tic = time.time()
+         logger.info(
+             f"Decode batch. "
+             f"#running-req: {len(self.running_batch.reqs)}, "
+             f"#token: {num_used}, "
+             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+             f"gen throughput (token/s): {throughput:.2f}, "
+             f"#queue-req: {len(self.waiting_queue)}"
+         )
+
+     def check_memory(self):
+         available_size = (
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+         )
+         if available_size != self.max_total_num_tokens:
+             warnings.warn(
+                 "Warning: "
+                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                 "KV cache pool leak detected!"
+             )
+             exit(1) if crash_on_warning else None
+
+         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+             warnings.warn(
+                 "Warning: "
+                 f"available req slots={len(self.req_to_token_pool.free_slots)}, "
+                 f"total slots={self.req_to_token_pool.size}\n"
+                 "Memory pool leak detected!"
+             )
+             exit(1) if crash_on_warning else None
+
+     def run_step(self):
+         new_batch = self.get_new_batch_prefill()
+         if new_batch is not None:
+             # Run a new prefill batch
+             # replace run_batch with the uncommented line to use pytorch profiler
+             # result = pytorch_profile(
+             #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
+             # )
+             result = self.run_batch(new_batch)
+             self.process_batch_result(new_batch, result)
+         else:
+             if self.running_batch is not None:
+                 # Run a few decode batches continuously for reducing overhead
+                 for _ in range(global_config.num_continue_decode_steps):
+                     batch = self.get_new_batch_decode()
+
+                     if batch:
+                         # replace run_batch with the uncommented line to use pytorch profiler
+                         # result = pytorch_profile(
+                         #     "profile_decode_step",
+                         #     self.run_batch,
+                         #     batch,
+                         #     data_size=len(batch.reqs),
+                         # )
+                         result = self.run_batch(batch)
+                         self.process_batch_result(batch, result)
+
+                     if self.running_batch is None:
+                         break
+
+                     if self.out_pyobjs and self.running_batch.has_stream:
+                         break
+             else:
+                 self.check_memory()
+                 self.new_token_ratio = global_config.init_new_token_ratio
+
+     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+         # Handle the cases where prefill is not allowed
+         if (
+             self.batch_is_full or len(self.waiting_queue) == 0
+         ) and self.current_inflight_req is None:
+             return None
+
+         running_bs = (
+             len(self.running_batch.reqs) if self.running_batch is not None else 0
+         )
+         if running_bs >= self.max_running_requests:
+             self.batch_is_full = True
+             return None
+
+         # Get priority queue
+         prefix_computed = self.policy.calc_priority(self.waiting_queue)
+
+         # Prefill policy
+         num_mixed_running = running_bs if self.is_mixed_chunk else 0
+         adder = PrefillAdder(
+             self.tree_cache,
+             self.running_batch,
+             self.new_token_ratio,
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+             self.max_prefill_tokens,
+             self.chunked_prefill_size,
+             num_mixed_running,
+         )
+
+         has_inflight = self.current_inflight_req is not None
+         if self.current_inflight_req is not None:
+             self.current_inflight_req.init_next_round_input(
+                 None if prefix_computed else self.tree_cache
+             )
+             self.current_inflight_req = adder.add_inflight_req(
+                 self.current_inflight_req
+             )
+
+         if self.lora_paths is not None:
+             lora_set = (
+                 set([req.lora_path for req in self.running_batch.reqs])
+                 if self.running_batch is not None
+                 else set([])
+             )
+
+         for req in self.waiting_queue:
+             if (
+                 self.lora_paths is not None
+                 and len(
+                     lora_set
+                     | set([req.lora_path for req in adder.can_run_list])
+                     | set([req.lora_path])
+                 )
+                 > self.max_loras_per_batch
+             ):
+                 self.batch_is_full = True
+                 break
+
+             if running_bs + len(adder.can_run_list) >= self.max_running_requests:
+                 self.batch_is_full = True
+                 break
+
+             req.init_next_round_input(None if prefix_computed else self.tree_cache)
+             res = adder.add_one_req(req)
+             if res != AddReqResult.CONTINUE:
+                 if res == AddReqResult.NO_TOKEN:
+                     self.batch_is_full = True
+                 break
+
+         can_run_list = adder.can_run_list
+
+         if adder.new_inflight_req is not None:
+             assert self.current_inflight_req is None
+             self.current_inflight_req = adder.new_inflight_req
+
+         if len(can_run_list) == 0:
+             return None
+
+         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+
+         # Print stats
+         if self.tp_rank == 0:
+             if isinstance(self.tree_cache, RadixCache):
+                 self.tree_cache_metrics["total"] += (
+                     adder.log_input_tokens + adder.log_hit_tokens
+                 ) / 10**9
+                 self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+                 tree_cache_hit_rate = (
+                     self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+                 )
+             else:
+                 tree_cache_hit_rate = 0.0
+
+             num_used = self.max_total_num_tokens - (
+                 self.token_to_kv_pool.available_size()
+                 + self.tree_cache.evictable_size()
+             )
+
+             if num_mixed_running > 0:
+                 logger.info(
+                     f"Prefill batch"
+                     f"(mixed #running-req: {num_mixed_running}). "
+                     f"#new-seq: {len(can_run_list)}, "
+                     f"#new-token: {adder.log_input_tokens}, "
+                     f"#cached-token: {adder.log_hit_tokens}, "
+                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                     f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                 )
+             else:
+                 logger.info(
+                     f"Prefill batch. "
+                     f"#new-seq: {len(can_run_list)}, "
+                     f"#new-token: {adder.log_input_tokens}, "
+                     f"#cached-token: {adder.log_hit_tokens}, "
+                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                     f"#running-req: {running_bs}, "
+                     f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                 )
+
+         # Create a new batch
+         new_batch = ScheduleBatch.init_new(
+             can_run_list,
+             self.req_to_token_pool,
+             self.token_to_kv_pool,
+             self.tree_cache,
+         )
+         new_batch.prepare_for_extend(self.model_config.vocab_size)
+
+         # Mixed-style chunked prefill
+         decoding_reqs = []
+         if self.is_mixed_chunk and self.running_batch is not None:
+             self.running_batch.prepare_for_decode()
+             new_batch.mix_with_running(self.running_batch)
+             decoding_reqs = self.running_batch.reqs
+             self.running_batch = None
+         new_batch.decoding_reqs = decoding_reqs
+
+         return new_batch
+
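The LoRA check in get_new_batch_prefill caps how many distinct adapters can be co-batched. A self-contained sketch of the same set arithmetic, with invented adapter names and an illustrative limit:

max_loras_per_batch = 4                              # illustrative limit, normally from ServerArgs

running_paths = {"adapter_a", "adapter_b"}           # adapters already in the running batch
accepted_paths = {"adapter_c"}                       # adapters of requests accepted so far this round
candidate_path = "adapter_d"                         # adapter of the next waiting request

# Mirrors: len(lora_set | set([...]) | set([req.lora_path])) > self.max_loras_per_batch
would_need = running_paths | accepted_paths | {candidate_path}
if len(would_need) > max_loras_per_batch:
    batch_is_full = True     # stop admitting prefill requests this round
else:
    accepted_paths.add(candidate_path)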
+     def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+         batch = self.running_batch
+
+         # Check if decode out of memory
+         if not batch.check_decode_mem():
+             old_ratio = self.new_token_ratio
+
+             retracted_reqs, new_token_ratio = batch.retract_decode()
+             self.new_token_ratio = new_token_ratio
+
+             logger.info(
+                 "Decode out of memory happened. "
+                 f"#retracted_reqs: {len(retracted_reqs)}, "
+                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+             )
+             self.waiting_queue.extend(retracted_reqs)
+         else:
+             self.new_token_ratio = max(
+                 self.new_token_ratio - self.new_token_ratio_decay,
+                 self.min_new_token_ratio,
+             )
+
+         # Check for jump-forward
+         if not self.disable_regex_jump_forward:
+             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
+             self.waiting_queue.extend(jump_forward_reqs)
+             if batch.is_empty():
+                 return None
+
+         # Update batch tensors
+         batch.prepare_for_decode()
+         return batch
+
+     def run_batch(self, batch: ScheduleBatch):
+         if self.is_generation:
+             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
+                 model_worker_batch = batch.get_model_worker_batch()
+                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                     model_worker_batch
+                 )
+             else:
+                 logits_output = None
+                 if self.tokenizer is not None:
+                     next_token_ids = torch.full(
+                         (batch.batch_size(),), self.tokenizer.eos_token_id
+                     )
+                 else:
+                     next_token_ids = torch.full((batch.batch_size(),), 0)
+             return logits_output, next_token_ids
+         else: # embedding or reward model
+             assert batch.extend_num_tokens != 0
+             model_worker_batch = batch.get_model_worker_batch()
+             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+             return embeddings
+
+     def process_batch_result(self, batch: ScheduleBatch, result):
+         if batch.forward_mode.is_decode():
+             self.process_batch_result_decode(batch, result)
+         else:
+             self.process_batch_result_prefill(batch, result)
+
+     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+         if self.is_generation:
+             logits_output, next_token_ids = result
+             batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                 next_token_ids
+             )
+
+             if logits_output:
+                 # Move logprobs to cpu
+                 if logits_output.next_token_logprobs is not None:
+                     logits_output.next_token_logprobs = (
+                         logits_output.next_token_logprobs[
+                             torch.arange(
+                                 len(next_token_ids), device=next_token_ids.device
+                             ),
+                             next_token_ids,
+                         ].tolist()
+                     )
+                     logits_output.input_token_logprobs = (
+                         logits_output.input_token_logprobs.tolist()
+                     )
+                     logits_output.normalized_prompt_logprobs = (
+                         logits_output.normalized_prompt_logprobs.tolist()
+                     )
+
+             next_token_ids = next_token_ids.tolist()
+
+             # Check finish conditions
+             logprob_pt = 0
+             for i, req in enumerate(batch.reqs):
+                 if req is not self.current_inflight_req:
+                     # Inflight reqs' prefill is not finished
+                     req.completion_tokens_wo_jump_forward += 1
+                     req.output_ids.append(next_token_ids[i])
+                     req.check_finished()
+
+                 if req.regex_fsm is not None:
+                     req.regex_fsm_state = req.regex_fsm.get_next_state(
+                         req.regex_fsm_state, next_token_ids[i]
+                     )
+
+                 if req.finished():
+                     self.tree_cache.cache_finished_req(req)
+                 elif req not in batch.decoding_reqs:
+                     # To reduce overhead, only cache prefill reqs
+                     self.tree_cache.cache_unfinished_req(req)
+
+                 if req is self.current_inflight_req:
+                     # Inflight request would get a new req idx
+                     self.req_to_token_pool.free(req.req_pool_idx)
+
+                 if req.return_logprob:
+                     logprob_pt += self.add_logprob_return_values(
+                         i, req, logprob_pt, next_token_ids, logits_output
+                     )
+         else: # embedding or reward model
+             assert batch.extend_num_tokens != 0
+             embeddings = result
+
+             # Check finish conditions
+             for i, req in enumerate(batch.reqs):
+                 req.embedding = embeddings[i]
+                 if req is not self.current_inflight_req:
+                     # Inflight reqs' prefill is not finished
+                     # dummy output token for embedding models
+                     req.output_ids.append(0)
+                     req.check_finished()
+
+                 if req.finished():
+                     self.tree_cache.cache_finished_req(req)
+                 else:
+                     self.tree_cache.cache_unfinished_req(req)
+
+                 if req is self.current_inflight_req:
+                     # Inflight request would get a new req idx
+                     self.req_to_token_pool.free(req.req_pool_idx)
+
+         self.handle_finished_requests(batch)
+
+         if not batch.is_empty():
+             if self.running_batch is None:
+                 self.running_batch = batch
+             else:
+                 self.running_batch.merge_batch(batch)
+
+     def process_batch_result_decode(self, batch: ScheduleBatch, result):
+         logits_output, next_token_ids = result
+         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+             next_token_ids
+         )
+         self.num_generated_tokens += len(batch.reqs)
+
+         # Move logprobs to cpu
+         if logits_output.next_token_logprobs is not None:
+             next_token_logprobs = logits_output.next_token_logprobs[
+                 torch.arange(len(next_token_ids), device=next_token_ids.device),
+                 next_token_ids,
+             ].tolist()
+
+         next_token_ids = next_token_ids.tolist()
+
+         # Check finish condition
+         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+             req.completion_tokens_wo_jump_forward += 1
+             req.output_ids.append(next_token_id)
+             req.check_finished()
+
+             if req.regex_fsm is not None:
+                 req.regex_fsm_state = req.regex_fsm.get_next_state(
+                     req.regex_fsm_state, next_token_id
+                 )
+
+             if req.finished():
+                 self.tree_cache.cache_finished_req(req)
+
+             if req.return_logprob:
+                 req.output_token_logprobs.append(
+                     (next_token_logprobs[i], next_token_id)
+                 )
+                 if req.top_logprobs_num > 0:
+                     req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+
+         self.handle_finished_requests(batch)
+
+         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+             self.print_decode_stats()
+
+         if self.running_batch.is_empty():
+             self.running_batch = None
+
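Both process_batch_result_prefill and process_batch_result_decode select each request's chosen-token logprob with the same advanced-indexing pattern. A tiny standalone torch example of that gather, with made-up shapes and token ids:

import torch

batch_size, vocab_size = 3, 10
# Pretend these are log-softmax scores for the last position of each request.
next_token_logprobs = torch.randn(batch_size, vocab_size).log_softmax(dim=-1)
next_token_ids = torch.tensor([4, 0, 7])   # sampled token id per request

# Same indexing as in the diff: row i, column next_token_ids[i].
chosen = next_token_logprobs[
    torch.arange(batch_size, device=next_token_ids.device),
    next_token_ids,
].tolist()
assert len(chosen) == batch_size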
+     def add_logprob_return_values(
+         self,
+         i: int,
+         req: Req,
+         pt: int,
+         next_token_ids: List[int],
+         output: LogitsProcessorOutput,
+     ):
+         """Attach logprobs to the return values."""
+         req.output_token_logprobs.append(
+             (output.next_token_logprobs[i], next_token_ids[i])
+         )
+
+         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
+
+         if req.normalized_prompt_logprob is None:
+             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+         if req.input_token_logprobs is None:
+             input_token_logprobs = output.input_token_logprobs[
+                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+             ]
+             input_token_ids = req.fill_ids[
+                 len(req.fill_ids)
+                 - num_input_logprobs
+                 + 1 : len(req.fill_ids)
+                 - req.last_update_decode_tokens
+             ]
+             req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
+
+             if (
+                 req.logprob_start_len == 0
+             ): # The first token does not have logprob, pad it.
+                 req.input_token_logprobs = [
+                     (None, req.fill_ids[0])
+                 ] + req.input_token_logprobs
+
+         if req.last_update_decode_tokens != 0:
+             # Some decode tokens are re-computed in an extend batch
+             req.output_token_logprobs.extend(
+                 list(
+                     zip(
+                         output.input_token_logprobs[
+                             pt
+                             + num_input_logprobs
+                             - 1
+                             - req.last_update_decode_tokens : pt
+                             + num_input_logprobs
+                             - 1
+                         ],
+                         req.fill_ids[
+                             len(req.fill_ids)
+                             - req.last_update_decode_tokens : len(req.fill_ids)
+                         ],
+                     )
+                 )
+             )
+
+         if req.top_logprobs_num > 0:
+             if req.input_top_logprobs is None:
+                 req.input_top_logprobs = output.input_top_logprobs[i]
+                 if req.logprob_start_len == 0:
+                     req.input_top_logprobs = [None] + req.input_top_logprobs
+
+             if req.last_update_decode_tokens != 0:
+                 req.output_top_logprobs.extend(
+                     output.input_top_logprobs[i][-req.last_update_decode_tokens :]
+                 )
+             req.output_top_logprobs.append(output.output_top_logprobs[i])
+
+         return num_input_logprobs
+
+     def handle_finished_requests(self, batch: ScheduleBatch):
+         output_rids = []
+         output_meta_info = []
+         output_finished_reason: List[BaseFinishReason] = []
+         if self.is_generation:
+             output_vids = []
+             decoded_texts = []
+             output_read_ids = []
+             output_read_offsets = []
+             output_skip_special_tokens = []
+             output_spaces_between_special_tokens = []
+         else: # embedding or reward model
+             output_embeddings = []
+         unfinished_indices = []
+
+         for i, req in enumerate(batch.reqs):
+             if not req.finished() and req is not self.current_inflight_req:
+                 unfinished_indices.append(i)
+             else:
+                 self.batch_is_full = False
+
+             if req.finished() or (
+                 req.stream
+                 and (
+                     self.decode_forward_ct % self.stream_interval == 0
+                     or len(req.output_ids) == 1
+                 )
+             ):
+                 output_rids.append(req.rid)
+                 output_finished_reason.append(req.finished_reason)
+                 if self.is_generation:
+                     output_vids.append(req.vid)
+                     decoded_texts.append(req.decoded_text)
+                     read_ids, read_offset = req.init_incremental_detokenize()
+                     output_read_ids.append(read_ids)
+                     output_read_offsets.append(read_offset)
+                     output_skip_special_tokens.append(
+                         req.sampling_params.skip_special_tokens
+                     )
+                     output_spaces_between_special_tokens.append(
+                         req.sampling_params.spaces_between_special_tokens
+                     )
+
+                     meta_info = {
+                         "prompt_tokens": len(req.origin_input_ids),
+                         "completion_tokens": len(req.output_ids),
+                         "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                         "finish_reason": (
+                             req.finished_reason.to_json()
+                             if req.finished_reason is not None
+                             else None
+                         ),
+                     }
+                     if req.return_logprob:
+                         (
+                             meta_info["input_token_logprobs"],
+                             meta_info["output_token_logprobs"],
+                             meta_info["input_top_logprobs"],
+                             meta_info["output_top_logprobs"],
+                             meta_info["normalized_prompt_logprob"],
+                         ) = (
+                             req.input_token_logprobs,
+                             req.output_token_logprobs,
+                             req.input_top_logprobs,
+                             req.output_top_logprobs,
+                             req.normalized_prompt_logprob,
+                         )
+                     output_meta_info.append(meta_info)
+                 else: # embedding or reward model
+                     output_embeddings.append(req.embedding)
+                     meta_info = {
+                         "prompt_tokens": len(req.origin_input_ids),
+                     }
+                     output_meta_info.append(meta_info)
+
+         # Send to detokenizer
+         if output_rids:
+             if self.is_generation:
+                 self.out_pyobjs.append(
+                     BatchTokenIDOut(
+                         output_rids,
+                         output_vids,
+                         decoded_texts,
+                         output_read_ids,
+                         output_read_offsets,
+                         output_skip_special_tokens,
+                         output_spaces_between_special_tokens,
+                         output_meta_info,
+                         output_finished_reason,
+                     )
+                 )
+             else: # embedding or reward model
+                 self.out_pyobjs.append(
+                     BatchEmbeddingOut(
+                         output_rids,
+                         output_embeddings,
+                         output_meta_info,
+                         output_finished_reason,
+                     )
+                 )
+
+         # Remove finished reqs: update batch tensors
+         batch.filter_batch(unfinished_indices)
+
+     def flush_cache(self):
+         if len(self.waiting_queue) == 0 and (
+             self.running_batch is None or len(self.running_batch.reqs) == 0
+         ):
+             self.tree_cache.reset()
+             self.tree_cache_metrics = {"total": 0, "hit": 0}
+             self.regex_fsm_cache.reset()
+             self.req_to_token_pool.clear()
+             self.token_to_kv_pool.clear()
+             torch.cuda.empty_cache()
+             logger.info("Cache flushed successfully!")
+             if_success = True
+         else:
+             logging.warning(
+                 f"Cache not flushed because there are pending requests. "
+                 f"#queue-req: {len(self.waiting_queue)}, "
+                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+             )
+             if_success = False
+         return if_success
+
+     def abort_request(self, recv_req: AbortReq):
+         # Delete requests in the waiting queue
+         to_del = None
+         for i, req in enumerate(self.waiting_queue):
+             if req.rid == recv_req.rid:
+                 to_del = i
+                 break
+
+         if to_del is not None:
+             del self.waiting_queue[to_del]
+
+         # Delete requests in the running batch
+         if self.running_batch:
+             for req in self.running_batch.reqs:
+                 if req.rid == recv_req.rid:
+                     req.finished_reason = FINISH_ABORT()
+                     break
+
+     def update_weights(self, recv_req: UpdateWeightReqInput):
+         success, message = self.tp_worker.update_weights(recv_req)
+         if success:
+             flash_cache_success = self.flush_cache()
+             assert flash_cache_success, "Cache flush failed after updating weights"
+         else:
+             logger.error(message)
+         return success, message
+
+
+ def run_scheduler_process(
+     server_args: ServerArgs,
+     port_args: PortArgs,
+     gpu_id: int,
+     tp_rank: int,
+     pipe_writer,
+ ):
+     configure_logger(server_args, prefix=f" TP{tp_rank}")
+     suppress_other_loggers()
+
+     try:
+         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+         pipe_writer.send("ready")
+         scheduler.event_loop()
+     except Exception:
+         msg = get_exception_traceback()
+         logger.error(msg)
+         kill_parent_process()
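run_scheduler_process is written to be the target of a subprocess: it builds the Scheduler, signals readiness over a pipe, then blocks in event_loop. The real launch wiring lives in sglang/srt/server.py; the sketch below is a hypothetical launcher that only relies on the argument order and the "ready" handshake visible in the diff above (the launch_tp0_scheduler helper and the fully populated server_args/port_args objects are assumptions):

import multiprocessing as mp

def launch_tp0_scheduler(server_args, port_args):
    # server_args / port_args are assumed to be fully populated ServerArgs / PortArgs
    # objects; their construction is elided here.
    from sglang.srt.managers.scheduler import run_scheduler_process

    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(
        target=run_scheduler_process,
        args=(server_args, port_args, 0, 0, writer),  # gpu_id=0, tp_rank=0
    )
    proc.start()
    assert reader.recv() == "ready"   # the scheduler signals readiness over the pipe
    return proc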