sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -43,32 +37,34 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
 
 
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -92,13 +88,15 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
 
+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -106,15 +104,19 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
 
+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flashinfer()
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
@@ -162,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -174,6 +179,7 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -184,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -251,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -289,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -311,7 +310,18 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
+
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
@@ -343,8 +353,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -378,7 +388,7 @@ class ModelRunner:
                     ),
                     2048,
                 ),
-                5120,
+                4096,
             )
 
         self.req_to_token_pool = ReqToTokenPool(
@@ -396,9 +406,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -421,118 +428,46 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flashinfer(self):
-        """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
-            assert (
-                self.sliding_window_size is None
-            ), "turn on flashinfer to support window attention"
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = None
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
             )
+            self.attn_backend = TritonAttnBackend(self)
         else:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
             )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
             return
 
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and batch.sampling_info.can_run_in_cuda_graph()
-        ):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
 
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
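
The hunk above retires the boolean `disable_flashinfer` / `disable_flashinfer_sampling` switches in favor of named `attention_backend` and `sampling_backend` options, dispatching to `FlashInferAttnBackend` or `TritonAttnBackend`. A rough migration sketch, assuming the `"flashinfer"`/`"triton"` attention values shown above and a `"pytorch"` sampling value that is not shown in this hunk:

    # Hedged sketch: mapping the removed 0.3.0 flags onto the 0.3.1 backend options.
    # The "pytorch" sampling value is an assumption; only "flashinfer" and "triton"
    # attention values appear in the diff itself.
    def migrate_backend_flags(disable_flashinfer: bool, disable_flashinfer_sampling: bool):
        attention_backend = "triton" if disable_flashinfer else "flashinfer"
        sampling_backend = "pytorch" if disable_flashinfer_sampling else "flashinfer"
        return {"attention_backend": attention_backend, "sampling_backend": sampling_backend}

    print(migrate_backend_flags(True, False))
    # {'attention_backend': 'triton', 'sampling_backend': 'flashinfer'}
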
@@ -540,11 +475,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
+
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -560,11 +494,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -574,17 +504,56 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids
 
 
 @lru_cache()
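
Taken together, the final hunks split generation into two explicit steps on `ModelRunner`: `forward(batch)` runs the extend or decode pass and returns a `LogitsProcessorOutput`, and `sample(logits_output, batch)` applies logit bias, linear and scaling penalties, and the regex vocab mask before drawing tokens with the shared `Sampler`. A simplified, hypothetical caller loop (the real driver is the reworked `sglang/srt/managers/tp_worker.py`; `schedule_next_batch` and `append_tokens` are invented names for illustration):

    # Hedged sketch of driving the new two-step ModelRunner API from a worker loop.
    def decode_loop(model_runner, scheduler):
        while True:
            # A ScheduleBatch with batch.forward_mode already set (extend or decode).
            batch = scheduler.schedule_next_batch()  # hypothetical helper
            if batch is None:
                break
            logits_output = model_runner.forward(batch)
            next_token_ids = model_runner.sample(logits_output, batch)
            scheduler.append_tokens(batch, next_token_ids)  # hypothetical helper
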