sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
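
The headline change in this file list is a restructuring of the runtime: sglang/srt/managers/tp_worker.py sheds over 900 lines (+28 -925) while a new sglang/srt/managers/scheduler.py (+1021) takes over batching, cache management, and request handling; policy_scheduler.py is renamed to schedule_policy.py; the controller_multi.py and controller_single.py processes are removed; and the attention kernels move under sglang/srt/layers/attention/ with dedicated flashinfer and triton backends replacing the old attention_backend.py. The sketch below illustrates the division of labor implied by the tp_worker.py diff that follows; only the TpModelWorker methods and the ModelWorkerBatch type appear in the diff, while the scheduler-side names (get_next_batch_to_run, process_batch_result) are hypothetical placeholders.

# Rough sketch of the new split, assuming a scheduler object that owns the
# batching policy (now in sglang/srt/managers/scheduler.py). Scheduler-side
# method names are invented for illustration; the worker-side calls mirror
# the methods added in the tp_worker.py diff below.
from sglang.srt.managers.tp_worker import TpModelWorker


def event_loop(scheduler, worker: TpModelWorker):
    # The worker profiles token/memory limits and reports them to the scheduler.
    (
        max_total_num_tokens,
        max_prefill_tokens,
        max_running_requests,
        max_req_input_len,
        random_seed,
    ) = worker.get_token_and_memory_info()

    while True:
        # Hypothetical: the scheduler decides what to run next and packs it
        # into a ModelWorkerBatch (see schedule_batch.py in the list above).
        model_worker_batch = scheduler.get_next_batch_to_run()
        if model_worker_batch is None:
            continue
        # The worker builds a ForwardBatch internally, runs the model, samples.
        logits_output, next_token_ids = worker.forward_batch_generation(model_worker_batch)
        # Hypothetical: hand the results back for detokenization and streaming.
        scheduler.process_batch_result(model_worker_batch, logits_output, next_token_ids)
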
@@ -17,60 +17,22 @@ limitations under the License.
 
 import json
 import logging
-import multiprocessing
-import os
-import pickle
-import time
-import warnings
-from typing import Any, List, Optional
 
-import torch
-import torch.distributed
-import torch.distributed as dist
-
-from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import (
-    AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
-    FlushCacheReq,
-    TokenizedEmbeddingReqInput,
-    TokenizedGenerateReqInput,
-    UpdateWeightReqInput,
-    UpdateWeightReqOutput,
-)
-from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
-from sglang.srt.managers.schedule_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Req,
-    ScheduleBatch,
-)
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import (
-    configure_logger,
-    is_multimodal_model,
-    set_random_seed,
-    suppress_other_loggers,
-)
-from sglang.utils import get_exception_traceback
+from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
 
 logger = logging.getLogger(__name__)
 
 
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
+class TpModelWorker:
+    """A tensor parallel model worker."""
 
-class ModelTpServer:
     def __init__(
         self,
         gpu_id: int,
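
The import hunk above replaces broadcast_recv_input (whose full implementation is deleted at the end of this diff) with broadcast_pyobj from sglang.srt.utils. A minimal sketch of such a helper, modeled directly on the removed broadcast_recv_input and shown only to make the intent concrete (rank 0 pickles the objects and broadcasts their size and bytes over the CPU process group); the shipped sglang.srt.utils.broadcast_pyobj may differ in details:

import pickle
from typing import Any, List

import torch
import torch.distributed as dist


def broadcast_pyobj(data: List[Any], rank: int, dist_group) -> List[Any]:
    """Broadcast picklable Python objects from rank 0 to all other ranks.

    Sketch based on the broadcast_recv_input helper removed in this diff; the
    actual implementation in sglang.srt.utils may differ.
    """
    if rank == 0:
        if len(data) == 0:
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(tensor_size, src=0, group=dist_group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            tensor_data = torch.ByteTensor(list(serialized_data))
            tensor_size = torch.tensor([size], dtype=torch.long)
            dist.broadcast(tensor_size, src=0, group=dist_group)
            dist.broadcast(tensor_data, src=0, group=dist_group)
        return data
    else:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        size = tensor_size.item()
        if size == 0:
            return []
        tensor_data = torch.empty(size, dtype=torch.uint8)
        dist.broadcast(tensor_data, src=0, group=dist_group)
        return pickle.loads(bytes(tensor_data.tolist()))

In the new worker the call site becomes self.random_seed = broadcast_pyobj([server_args.random_seed], self.tp_rank, self.model_runner.tp_group.cpu_group)[0], as shown in the last hunk of this file.
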
@@ -78,17 +40,8 @@ class ModelTpServer:
         server_args: ServerArgs,
         nccl_port: int,
     ):
-        suppress_other_loggers()
-
-        # Parse arguments
-        self.gpu_id = gpu_id
+        # Parse args
         self.tp_rank = tp_rank
-        self.tp_size = server_args.tp_size
-        self.dp_size = server_args.dp_size
-        self.schedule_policy = server_args.schedule_policy
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        self.lora_paths = server_args.lora_paths
-        self.max_loras_per_batch = server_args.max_loras_per_batch
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -122,6 +75,8 @@ class ModelTpServer:
                     tokenizer_mode=server_args.tokenizer_mode,
                     trust_remote_code=server_args.trust_remote_code,
                 )
+
+        # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens
         self.max_running_requests = min(
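
The "# Profile number of tokens" block kept by the slimmed-down worker is what it later exposes through a small accessor: further down in this diff the worker gains get_token_and_memory_info, which returns these limits as a tuple, presumably for the scheduler process to consume. A small usage sketch, with the tuple order taken from the added lines below (the tp_worker argument and the print format are illustrative only):

def report_worker_limits(tp_worker) -> None:
    # tp_worker is assumed to be an already-constructed TpModelWorker.
    (
        max_total_num_tokens,   # KV-cache capacity profiled by the ModelRunner
        max_prefill_tokens,     # per-prefill-batch token budget from ServerArgs
        max_running_requests,   # cap on concurrently running requests
        max_req_input_len,      # longest admissible request input
        random_seed,            # seed synced across TP ranks via broadcast_pyobj
    ) = tp_worker.get_token_and_memory_info()
    print(
        f"max_total_num_tokens={max_total_num_tokens}, "
        f"max_prefill_tokens={max_prefill_tokens}, "
        f"max_running_requests={max_running_requests}, "
        f"max_req_input_len={max_req_input_len}, "
        f"random_seed={random_seed}"
    )
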
@@ -138,888 +93,36 @@ class ModelTpServer:
         )
 
         # Sync random seed across TP workers
-        server_args.random_seed = broadcast_recv_input(
+        self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
             self.tp_rank,
             self.model_runner.tp_group.cpu_group,
         )[0]
-        set_random_seed(server_args.random_seed)
147
-
148
- # Print debug info
149
- logger.info(
150
- f"max_total_num_tokens={self.max_total_num_tokens}, "
151
- f"max_prefill_tokens={self.max_prefill_tokens}, "
152
- f"max_running_requests={self.max_running_requests}, "
153
- f"context_len={self.model_config.context_len}"
154
- )
155
-
156
- # Init cache
157
- if (
158
- server_args.chunked_prefill_size is not None
159
- and server_args.disable_radix_cache
160
- ):
161
- self.tree_cache = ChunkCache(
162
- req_to_token_pool=self.model_runner.req_to_token_pool,
163
- token_to_kv_pool=self.model_runner.token_to_kv_pool,
164
- )
165
- else:
166
- self.tree_cache = RadixCache(
167
- req_to_token_pool=self.model_runner.req_to_token_pool,
168
- token_to_kv_pool=self.model_runner.token_to_kv_pool,
169
- disable=server_args.disable_radix_cache,
170
- )
171
- self.tree_cache_metrics = {"total": 0, "hit": 0}
172
- self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
173
- self.req_to_token_pool = self.model_runner.req_to_token_pool
174
- self.token_to_kv_pool = self.model_runner.token_to_kv_pool
175
-
176
- # Init running status
177
- self.waiting_queue: List[Req] = []
178
- self.running_batch: ScheduleBatch = None
179
- self.out_pyobjs = []
180
- self.decode_forward_ct = 0
181
- self.stream_interval = server_args.stream_interval
182
- self.num_generated_tokens = 0
183
- self.last_stats_tic = time.time()
184
-
185
- # Init chunked prefill
186
- self.chunked_prefill_size = server_args.chunked_prefill_size
187
- self.current_inflight_req = None
188
- self.is_mixed_chunk = (
189
- self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
190
- )
191
-
192
- # Init the FSM cache for constrained generation
193
- if not server_args.skip_tokenizer_init:
194
- self.regex_fsm_cache = FSMCache(
195
- server_args.tokenizer_path,
196
- {
197
- "tokenizer_mode": server_args.tokenizer_mode,
198
- "trust_remote_code": server_args.trust_remote_code,
199
- },
200
- skip_tokenizer_init=server_args.skip_tokenizer_init,
201
- constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
202
- )
203
- self.jump_forward_cache = JumpForwardCache()
204
-
205
- # Init new token estimation
206
- assert (
207
- server_args.schedule_conservativeness >= 0
208
- ), "Invalid schedule_conservativeness"
209
- self.min_new_token_ratio = min(
210
- global_config.base_min_new_token_ratio
211
- * server_args.schedule_conservativeness,
212
- 1.0,
213
- )
214
- self.new_token_ratio = self.min_new_token_ratio
215
- self.new_token_ratio_decay = global_config.new_token_ratio_decay
216
- self.do_not_get_new_batch = False
217
-
218
- @torch.inference_mode()
219
- def exposed_step(self, recv_reqs: List):
220
- try:
221
- # Recv requests
222
- for recv_req in recv_reqs:
223
- if isinstance(recv_req, TokenizedGenerateReqInput):
224
- self.handle_generate_request(recv_req)
225
- self.do_not_get_new_batch = False
226
- elif isinstance(recv_req, TokenizedEmbeddingReqInput):
227
- self.handle_embedding_request(recv_req)
228
- self.do_not_get_new_batch = False
229
- elif isinstance(recv_req, FlushCacheReq):
230
- self.flush_cache()
231
- elif isinstance(recv_req, AbortReq):
232
- self.abort_request(recv_req)
233
- elif isinstance(recv_req, UpdateWeightReqInput):
234
- success, message = self.update_weights(recv_req)
235
- self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
236
- else:
237
- raise ValueError(f"Invalid request: {recv_req}")
238
-
239
- # Forward
240
- self.forward_step()
241
- except Exception:
242
- logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
243
- raise
244
-
245
- # Return results
246
- ret = self.out_pyobjs
247
- self.out_pyobjs = []
248
- return ret
249
-
250
- def forward_step(self):
251
- if self.do_not_get_new_batch and self.current_inflight_req is None:
252
- new_batch = None
253
- else:
254
- new_batch = self.get_new_prefill_batch()
255
- self.do_not_get_new_batch = False
256
-
257
- if new_batch is not None:
258
- # Run a new prefill batch
259
- self.forward_prefill_batch(new_batch)
101
+ set_random_seed(self.random_seed)
260
102
 
261
- if not new_batch.is_empty():
262
- if self.running_batch is None:
263
- self.running_batch = new_batch
264
- else:
265
- self.running_batch.merge(new_batch)
266
- else:
267
- # Run a decode batch
268
- if self.running_batch is not None:
269
- # Run a few decode batches continuously for reducing overhead
270
- for _ in range(global_config.num_continue_decode_steps):
271
- self.num_generated_tokens += len(self.running_batch.reqs)
272
- self.forward_decode_batch(self.running_batch)
273
-
274
- # Print stats
275
- if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
276
- self.print_decode_stats()
277
-
278
- if self.running_batch.is_empty():
279
- self.running_batch = None
280
- break
281
-
282
- if self.out_pyobjs and self.running_batch.has_stream:
283
- break
284
- else:
285
- self.check_memory()
286
- self.new_token_ratio = global_config.init_new_token_ratio
287
-
288
- def print_decode_stats(self):
289
- num_used = self.max_total_num_tokens - (
290
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
291
- )
292
- throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
293
- self.num_generated_tokens = 0
294
- self.last_stats_tic = time.time()
295
- logger.info(
296
- f"Decode batch. "
297
- f"#running-req: {len(self.running_batch.reqs)}, "
298
- f"#token: {num_used}, "
299
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
300
- f"gen throughput (token/s): {throughput:.2f}, "
301
- f"#queue-req: {len(self.waiting_queue)}"
302
- )
303
-
304
- def check_memory(self):
305
- available_size = (
306
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
307
- )
308
- if available_size != self.max_total_num_tokens:
309
- warnings.warn(
310
- "Warning: "
311
- f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
312
- "KV cache pool leak detected!"
313
- )
314
- exit(1) if crash_on_warning else None
315
-
316
- if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
317
- warnings.warn(
318
- "Warning: "
319
- f"available req slots={len(self.req_to_token_pool.free_slots)}, "
320
- f"total slots={self.req_to_token_pool.size}\n"
321
- "Memory pool leak detected!"
322
- )
323
- exit(1) if crash_on_warning else None
324
-
325
- def handle_generate_request(
326
- self,
327
- recv_req: TokenizedGenerateReqInput,
328
- ):
329
- if isinstance(recv_req, TokenizedGenerateReqInput):
330
- req = Req(
331
- recv_req.rid,
332
- recv_req.input_text,
333
- recv_req.input_ids,
334
- lora_path=recv_req.lora_path,
335
- )
336
- else:
337
- req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
338
- req.tokenizer = self.tokenizer
339
- req.sampling_params = recv_req.sampling_params
340
- req.pixel_values = recv_req.pixel_values
341
- if req.pixel_values is not None:
342
- # Use image hash as fake token_ids, which is then used
343
- # for prefix matching
344
- image_hash = hash(tuple(recv_req.image_hashes))
345
- req.pad_value = [
346
- (image_hash) % self.model_config.vocab_size,
347
- (image_hash >> 16) % self.model_config.vocab_size,
348
- (image_hash >> 32) % self.model_config.vocab_size,
349
- (image_hash >> 64) % self.model_config.vocab_size,
350
- ]
351
- req.image_sizes = recv_req.image_sizes
352
- (
353
- req.origin_input_ids,
354
- req.image_offsets,
355
- ) = self.model_runner.model.pad_input_ids(
356
- req.origin_input_ids_unpadded,
357
- req.pad_value,
358
- req.pixel_values,
359
- req.image_sizes,
360
- )
361
- # Only when pixel values is not None we have modalities
362
- req.modalities = recv_req.modalites
363
- req.return_logprob = recv_req.return_logprob
364
- req.top_logprobs_num = recv_req.top_logprobs_num
365
- req.stream = recv_req.stream
366
- req.logprob_start_len = recv_req.logprob_start_len
367
-
368
- if req.logprob_start_len == -1:
369
- # By default, only return the logprobs for output tokens
370
- req.logprob_start_len = len(recv_req.input_ids) - 1
371
-
372
- # Init regex FSM
373
- if (
374
- req.sampling_params.json_schema is not None
375
- or req.sampling_params.regex is not None
376
- ):
377
- if req.sampling_params.json_schema is not None:
378
- req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
379
- ("json", req.sampling_params.json_schema)
380
- )
381
- elif req.sampling_params.regex is not None:
382
- req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
383
- ("regex", req.sampling_params.regex)
384
- )
385
- if not self.disable_regex_jump_forward:
386
- req.jump_forward_map = self.jump_forward_cache.query(
387
- computed_regex_string
388
- )
389
-
390
- # Truncate prompts that are too long
391
- if len(req.origin_input_ids) >= self.max_req_input_len:
392
- logger.warning(
393
- "Request length is longer than the KV cache pool size or "
394
- "the max context length. Truncated!!!"
395
- )
396
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
397
- req.sampling_params.max_new_tokens = min(
398
- (
399
- req.sampling_params.max_new_tokens
400
- if req.sampling_params.max_new_tokens is not None
401
- else 1 << 30
402
- ),
403
- self.max_req_input_len - 1 - len(req.origin_input_ids),
404
- )
405
-
406
- self.waiting_queue.append(req)
407
-
408
- def handle_embedding_request(
409
- self,
410
- recv_req: TokenizedEmbeddingReqInput,
411
- ):
412
- req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
413
- req.tokenizer = self.tokenizer
414
- req.sampling_params = recv_req.sampling_params
415
-
416
- # Truncate prompts that are too long
417
- if len(req.origin_input_ids) >= self.max_req_input_len:
418
- logger.warning(
419
- "Request length is longer than the KV cache pool size or "
420
- "the max context length. Truncated!!!"
421
- )
422
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
423
-
424
- self.waiting_queue.append(req)
425
-
426
- def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
427
- running_bs = (
428
- len(self.running_batch.reqs) if self.running_batch is not None else 0
429
- )
430
- if running_bs >= self.max_running_requests:
431
- return None
432
-
433
- # Get priority queue
434
- prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
435
-
436
- num_mixed_running = running_bs if self.is_mixed_chunk else 0
437
-
438
- adder = PrefillAdder(
439
- self.tree_cache,
440
- self.running_batch,
441
- self.new_token_ratio,
442
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
103
+ def get_token_and_memory_info(self):
104
+ return (
105
+ self.max_total_num_tokens,
443
106
  self.max_prefill_tokens,
444
- self.chunked_prefill_size,
445
- num_mixed_running,
446
- )
447
-
448
- has_inflight = self.current_inflight_req is not None
449
- if self.current_inflight_req is not None:
450
- self.current_inflight_req.init_next_round_input(
451
- None if prefix_computed else self.tree_cache
452
- )
453
- self.current_inflight_req = adder.add_inflight_req(
454
- self.current_inflight_req
455
- )
456
-
457
- if self.lora_paths is not None:
458
- lora_set = (
459
- set([req.lora_path for req in self.running_batch.reqs])
460
- if self.running_batch is not None
461
- else set([])
462
- )
463
-
464
- for req in self.waiting_queue:
465
- if (
466
- self.lora_paths is not None
467
- and len(
468
- lora_set
469
- | set([req.lora_path for req in adder.can_run_list])
470
- | set([req.lora_path])
471
- )
472
- > self.max_loras_per_batch
473
- ):
474
- break
475
-
476
- if adder.no_remaining_tokens():
477
- break
478
- req.init_next_round_input(None if prefix_computed else self.tree_cache)
479
- res = adder.add_one_req(req)
480
- if (
481
- not res
482
- or running_bs + len(adder.can_run_list) >= self.max_running_requests
483
- ):
484
- break
485
-
486
- can_run_list = adder.can_run_list
487
-
488
- if adder.new_inflight_req is not None:
489
- assert self.current_inflight_req is None
490
- self.current_inflight_req = adder.new_inflight_req
491
-
492
- if len(can_run_list) == 0:
493
- return None
494
-
495
- # Print stats
496
- if self.tp_rank == 0:
497
- if isinstance(self.tree_cache, RadixCache):
498
- self.tree_cache_metrics["total"] += (
499
- adder.log_input_tokens + adder.log_hit_tokens
500
- ) / 10**9
501
- self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
502
- tree_cache_hit_rate = (
503
- self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
504
- )
505
- else:
506
- tree_cache_hit_rate = 0.0
507
-
508
- num_used = self.max_total_num_tokens - (
509
- self.token_to_kv_pool.available_size()
510
- + self.tree_cache.evictable_size()
511
- )
512
-
513
- if num_mixed_running > 0:
514
- logger.info(
515
- f"Prefill batch"
516
- f"(mixed #running-req: {num_mixed_running}). "
517
- f"#new-seq: {len(can_run_list)}, "
518
- f"#new-token: {adder.log_input_tokens}, "
519
- f"#cached-token: {adder.log_hit_tokens}, "
520
- f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
521
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
522
- f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
523
- )
524
- else:
525
- logger.info(
526
- f"Prefill batch. "
527
- f"#new-seq: {len(can_run_list)}, "
528
- f"#new-token: {adder.log_input_tokens}, "
529
- f"#cached-token: {adder.log_hit_tokens}, "
530
- f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
531
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
532
- f"#running-req: {running_bs}, "
533
- f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
534
- )
535
-
536
- # Return the new batch
537
- new_batch = ScheduleBatch.init_new(
538
- can_run_list,
539
- self.req_to_token_pool,
540
- self.token_to_kv_pool,
541
- self.tree_cache,
542
- )
543
- self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
544
- return new_batch
545
-
546
- def forward_prefill_batch(self, batch: ScheduleBatch):
547
- # Build batch tensors
548
- batch.prepare_for_extend(self.model_config.vocab_size)
549
-
550
- decoding_reqs = []
551
- if self.is_mixed_chunk and self.running_batch is not None:
552
- self.running_batch.prepare_for_decode()
553
- batch.mix_with_running(self.running_batch)
554
- decoding_reqs = self.running_batch.reqs
555
- self.running_batch = None
556
-
557
- if self.model_runner.is_generation:
558
- # Forward and sample the next tokens
559
- if batch.extend_num_tokens != 0:
560
- logits_output = self.model_runner.forward(batch)
561
- next_token_ids = self.model_runner.sample(logits_output, batch)
562
-
563
- batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
564
- next_token_ids
565
- )
566
-
567
- # Move logprobs to cpu
568
- if logits_output.next_token_logprobs is not None:
569
- logits_output.next_token_logprobs = (
570
- logits_output.next_token_logprobs[
571
- torch.arange(
572
- len(next_token_ids), device=next_token_ids.device
573
- ),
574
- next_token_ids,
575
- ].tolist()
576
- )
577
- logits_output.input_token_logprobs = (
578
- logits_output.input_token_logprobs.tolist()
579
- )
580
- logits_output.normalized_prompt_logprobs = (
581
- logits_output.normalized_prompt_logprobs.tolist()
582
- )
583
-
584
- next_token_ids = next_token_ids.tolist()
585
- else:
586
- if self.tokenizer is None:
587
- next_token_ids = []
588
- for req in batch.reqs:
589
- next_token_ids.append(
590
- next(iter(req.sampling_params.stop_token_ids))
591
- )
592
- else:
593
- next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
594
-
595
- # Check finish conditions
596
- logprob_pt = 0
597
- for i, req in enumerate(batch.reqs):
598
- if req is not self.current_inflight_req:
599
- # Inflight reqs' prefill is not finished
600
- req.completion_tokens_wo_jump_forward += 1
601
- req.output_ids.append(next_token_ids[i])
602
- req.check_finished()
603
-
604
- if req.regex_fsm is not None:
605
- req.regex_fsm_state = req.regex_fsm.get_next_state(
606
- req.regex_fsm_state, next_token_ids[i]
607
- )
608
-
609
- if req.finished():
610
- self.tree_cache.cache_finished_req(req)
611
- elif req not in decoding_reqs:
612
- # To reduce overhead, only cache prefill reqs
613
- self.tree_cache.cache_unfinished_req(req)
614
-
615
- if req is self.current_inflight_req:
616
- # Inflight request would get a new req idx
617
- self.req_to_token_pool.free(req.req_pool_idx)
618
-
619
- if req.return_logprob:
620
- logprob_pt += self.add_logprob_return_values(
621
- i, req, logprob_pt, next_token_ids, logits_output
622
- )
623
- else:
624
- assert batch.extend_num_tokens != 0
625
- logits_output = self.model_runner.forward(batch)
626
- embeddings = logits_output.embeddings.tolist()
627
-
628
- # Check finish conditions
629
- for i, req in enumerate(batch.reqs):
630
- req.embedding = embeddings[i]
631
- if req is not self.current_inflight_req:
632
- # Inflight reqs' prefill is not finished
633
- # dummy output token for embedding models
634
- req.output_ids.append(0)
635
- req.check_finished()
636
-
637
- if req.finished():
638
- self.tree_cache.cache_finished_req(req)
639
- else:
640
- self.tree_cache.cache_unfinished_req(req)
641
-
642
- if req is self.current_inflight_req:
643
- # Inflight request would get a new req idx
644
- self.req_to_token_pool.free(req.req_pool_idx)
645
-
646
- self.handle_finished_requests(batch)
647
-
648
- def add_logprob_return_values(
649
- self,
650
- i: int,
651
- req: Req,
652
- pt: int,
653
- next_token_ids: List[int],
654
- output: LogitsProcessorOutput,
655
- ):
656
- """Attach logprobs to the return values."""
657
- req.output_token_logprobs.append(
658
- (output.next_token_logprobs[i], next_token_ids[i])
107
+ self.max_running_requests,
108
+ self.max_req_input_len,
109
+ self.random_seed,
659
110
  )
660
111
 
661
- # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
662
- num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
663
-
664
- if req.normalized_prompt_logprob is None:
665
- req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
112
+ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
113
+ forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
114
+ logits_output = self.model_runner.forward(forward_batch)
115
+ next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
116
+ return logits_output, next_token_ids
666
117
 
667
- if req.input_token_logprobs is None:
668
- input_token_logprobs = output.input_token_logprobs[
669
- pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
670
- ]
671
- input_token_ids = req.fill_ids[
672
- len(req.fill_ids)
673
- - num_input_logprobs
674
- + 1 : len(req.fill_ids)
675
- - req.last_update_decode_tokens
676
- ]
677
- req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
118
+ def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
119
+ forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
120
+ logits_output = self.model_runner.forward(forward_batch)
121
+ embeddings = logits_output.embeddings.tolist()
122
+ return embeddings
678
123
 
679
- if (
680
- req.logprob_start_len == 0
681
- ): # The first token does not have logprob, pad it.
682
- req.input_token_logprobs = [
683
- (None, req.fill_ids[0])
684
- ] + req.input_token_logprobs
685
-
686
- if req.last_update_decode_tokens != 0:
687
- # Some decode tokens are re-computed in an extend batch
688
- req.output_token_logprobs.extend(
689
- list(
690
- zip(
691
- output.input_token_logprobs[
692
- pt
693
- + num_input_logprobs
694
- - 1
695
- - req.last_update_decode_tokens : pt
696
- + num_input_logprobs
697
- - 1
698
- ],
699
- req.fill_ids[
700
- len(req.fill_ids)
701
- - req.last_update_decode_tokens : len(req.fill_ids)
702
- ],
703
- )
704
- )
705
- )
706
-
707
- if req.top_logprobs_num > 0:
708
- if req.input_top_logprobs is None:
709
- req.input_top_logprobs = output.input_top_logprobs[i]
710
- if req.logprob_start_len == 0:
711
- req.input_top_logprobs = [None] + req.input_top_logprobs
712
-
713
- if req.last_update_decode_tokens != 0:
714
- req.output_top_logprobs.extend(
715
- output.input_top_logprobs[i][-req.last_update_decode_tokens :]
716
- )
717
- req.output_top_logprobs.append(output.output_top_logprobs[i])
718
-
719
- return num_input_logprobs
720
-
721
- def forward_decode_batch(self, batch: ScheduleBatch):
722
- # Check if decode out of memory
723
- if not batch.check_decode_mem():
724
- old_ratio = self.new_token_ratio
725
-
726
- retracted_reqs, new_token_ratio = batch.retract_decode()
727
- self.new_token_ratio = new_token_ratio
728
-
729
- logger.info(
730
- "Decode out of memory happened. "
731
- f"#retracted_reqs: {len(retracted_reqs)}, "
732
- f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
733
- )
734
- self.waiting_queue.extend(retracted_reqs)
735
- else:
736
- self.new_token_ratio = max(
737
- self.new_token_ratio - self.new_token_ratio_decay,
738
- self.min_new_token_ratio,
739
- )
740
-
741
- if not self.disable_regex_jump_forward:
742
- # Check for jump-forward
743
- jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
744
- self.waiting_queue.extend(jump_forward_reqs)
745
- if batch.is_empty():
746
- return
747
-
748
- # Update batch tensors
749
- self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
750
- batch.prepare_for_decode()
751
-
752
- # Forward and sample the next tokens
753
- logits_output = self.model_runner.forward(batch)
754
- next_token_ids = self.model_runner.sample(logits_output, batch)
755
- batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
756
- next_token_ids
757
- )
758
-
759
- # Move logprobs to cpu
760
- if logits_output.next_token_logprobs is not None:
761
- next_token_logprobs = logits_output.next_token_logprobs[
762
- torch.arange(len(next_token_ids), device=next_token_ids.device),
763
- next_token_ids,
764
- ].tolist()
765
-
766
- next_token_ids = next_token_ids.tolist()
767
-
768
- # Check finish condition
769
- has_finished = False
770
- for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
771
- req.completion_tokens_wo_jump_forward += 1
772
- req.output_ids.append(next_token_id)
773
- req.check_finished()
774
-
775
- if req.regex_fsm is not None:
776
- req.regex_fsm_state = req.regex_fsm.get_next_state(
777
- req.regex_fsm_state, next_token_id
778
- )
779
-
780
- if req.finished():
781
- self.tree_cache.cache_finished_req(req)
782
- has_finished = True
783
-
784
- if req.return_logprob:
785
- req.output_token_logprobs.append(
786
- (next_token_logprobs[i], next_token_id)
787
- )
788
- if req.top_logprobs_num > 0:
789
- req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
790
-
791
- if not has_finished:
792
- self.do_not_get_new_batch = True
793
-
794
- self.handle_finished_requests(batch)
795
-
796
- def handle_finished_requests(self, batch: ScheduleBatch):
797
- output_rids = []
798
- output_meta_info = []
799
- output_finished_reason: List[BaseFinishReason] = []
800
- if self.model_runner.is_generation:
801
- output_vids = []
802
- decoded_texts = []
803
- output_read_ids = []
804
- output_read_offsets = []
805
- output_skip_special_tokens = []
806
- output_spaces_between_special_tokens = []
807
- else: # for embedding model
808
- output_embeddings = []
809
- unfinished_indices = []
810
-
811
- for i, req in enumerate(batch.reqs):
812
- if not req.finished() and req is not self.current_inflight_req:
813
- unfinished_indices.append(i)
814
-
815
- if req.finished() or (
816
- req.stream
817
- and (
818
- self.decode_forward_ct % self.stream_interval == 0
819
- or len(req.output_ids) == 1
820
- )
821
- ):
822
- output_rids.append(req.rid)
823
- output_finished_reason.append(req.finished_reason)
824
- if self.model_runner.is_generation:
825
- output_vids.append(req.vid)
826
- decoded_texts.append(req.decoded_text)
827
- read_ids, read_offset = req.init_incremental_detokenize()
828
- output_read_ids.append(read_ids)
829
- output_read_offsets.append(read_offset)
830
- output_skip_special_tokens.append(
831
- req.sampling_params.skip_special_tokens
832
- )
833
- output_spaces_between_special_tokens.append(
834
- req.sampling_params.spaces_between_special_tokens
835
- )
836
-
837
- meta_info = {
838
- "prompt_tokens": len(req.origin_input_ids),
839
- "completion_tokens": len(req.output_ids),
840
- "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
841
- "finish_reason": (
842
- req.finished_reason.to_json()
843
- if req.finished_reason is not None
844
- else None
845
- ),
846
- }
847
- if req.return_logprob:
848
- (
849
- meta_info["input_token_logprobs"],
850
- meta_info["output_token_logprobs"],
851
- meta_info["input_top_logprobs"],
852
- meta_info["output_top_logprobs"],
853
- meta_info["normalized_prompt_logprob"],
854
- ) = (
855
- req.input_token_logprobs,
856
- req.output_token_logprobs,
857
- req.input_top_logprobs,
858
- req.output_top_logprobs,
859
- req.normalized_prompt_logprob,
860
- )
861
- output_meta_info.append(meta_info)
862
- else: # for embedding model
863
- output_embeddings.append(req.embedding)
864
- meta_info = {
865
- "prompt_tokens": len(req.origin_input_ids),
866
- }
867
- output_meta_info.append(meta_info)
868
-
869
- # Send to detokenizer
870
- if output_rids:
871
- if self.model_runner.is_generation:
872
- self.out_pyobjs.append(
873
- BatchTokenIDOut(
874
- output_rids,
875
- output_vids,
876
- decoded_texts,
877
- output_read_ids,
878
- output_read_offsets,
879
- output_skip_special_tokens,
880
- output_spaces_between_special_tokens,
881
- output_meta_info,
882
- output_finished_reason,
883
- )
884
- )
885
- else: # for embedding model
886
- self.out_pyobjs.append(
887
- BatchEmbeddingOut(
888
- output_rids,
889
- output_embeddings,
890
- output_meta_info,
891
- output_finished_reason,
892
- )
893
- )
894
-
895
- # Remove finished reqs: update batch tensors
896
- batch.filter_batch(unfinished_indices)
897
-
898
- def flush_cache(self):
899
- if len(self.waiting_queue) == 0 and (
900
- self.running_batch is None or len(self.running_batch.reqs) == 0
901
- ):
902
- self.tree_cache.reset()
903
- self.tree_cache_metrics = {"total": 0, "hit": 0}
904
- self.regex_fsm_cache.reset()
905
- self.req_to_token_pool.clear()
906
- self.token_to_kv_pool.clear()
907
- torch.cuda.empty_cache()
908
- logger.info("Cache flushed successfully!")
909
- if_success = True
910
- else:
911
- logging.warning(
912
- f"Cache not flushed because there are pending requests. "
913
- f"#queue-req: {len(self.waiting_queue)}, "
914
- f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
915
- )
916
- if_success = False
917
- return if_success
918
-
919
- def abort_request(self, recv_req):
920
- # Delete requests in the waiting queue
921
- to_del = None
922
- for i, req in enumerate(self.waiting_queue):
923
- if req.rid == recv_req.rid:
924
- to_del = i
925
- break
926
-
927
- if to_del is not None:
928
- del self.waiting_queue[to_del]
929
-
930
- # Delete requests in the running batch
931
- if self.running_batch:
932
- for req in self.running_batch.reqs:
933
- if req.rid == recv_req.rid:
934
- req.finished_reason = FINISH_ABORT()
935
- break
936
-
937
- def update_weights(self, recv_req):
124
+ def update_weights(self, recv_req: UpdateWeightReqInput):
938
125
  success, message = self.model_runner.update_weights(
939
126
  recv_req.model_path, recv_req.load_format
940
127
  )
941
- if success:
942
- flash_cache_success = self.flush_cache()
943
- assert flash_cache_success, "Cache flush failed after updating weights"
944
- else:
945
- logger.error(message)
946
128
  return success, message
947
-
948
-
949
- def run_tp_server(
950
- gpu_id: int,
951
- tp_rank: int,
952
- server_args: ServerArgs,
953
- nccl_port: int,
954
- ):
955
- """Run a tensor parallel model server."""
956
- configure_logger(server_args, prefix=f" TP{tp_rank}")
957
-
958
- try:
959
- model_server = ModelTpServer(
960
- gpu_id,
961
- tp_rank,
962
- server_args,
963
- nccl_port,
964
- )
965
- tp_cpu_group = model_server.model_runner.tp_group.cpu_group
966
-
967
- while True:
968
- recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
969
- model_server.exposed_step(recv_reqs)
970
- except Exception:
971
- logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
972
- raise
973
-
974
-
975
- def launch_tp_servers(
976
- gpu_ids: List[int],
977
- tp_rank_range: List[int],
978
- server_args: ServerArgs,
979
- nccl_port: int,
980
- ):
981
- """Launch multiple tensor parallel servers."""
982
- procs = []
983
- for i in tp_rank_range:
984
- proc = multiprocessing.Process(
985
- target=run_tp_server,
986
- args=(gpu_ids[i], i, server_args, nccl_port),
987
- )
988
- proc.start()
989
- procs.append(proc)
990
-
991
- return procs
992
-
993
-
994
- def broadcast_recv_input(
995
- data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
996
- ):
997
- """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
998
-
999
- if rank == 0:
1000
- if len(data) == 0:
1001
- tensor_size = torch.tensor([0], dtype=torch.long)
1002
- dist.broadcast(tensor_size, src=0, group=dist_group)
1003
- else:
1004
- serialized_data = pickle.dumps(data)
1005
- size = len(serialized_data)
1006
- tensor_data = torch.ByteTensor(list(serialized_data))
1007
- tensor_size = torch.tensor([size], dtype=torch.long)
1008
-
1009
- dist.broadcast(tensor_size, src=0, group=dist_group)
1010
- dist.broadcast(tensor_data, src=0, group=dist_group)
1011
- return data
1012
- else:
1013
- tensor_size = torch.tensor([0], dtype=torch.long)
1014
- dist.broadcast(tensor_size, src=0, group=dist_group)
1015
- size = tensor_size.item()
1016
-
1017
- if size == 0:
1018
- return []
1019
-
1020
- tensor_data = torch.empty(size, dtype=torch.uint8)
1021
- dist.broadcast(tensor_data, src=0, group=dist_group)
1022
-
1023
- serialized_data = bytes(tensor_data.tolist())
1024
- data = pickle.loads(serialized_data)
1025
- return data