sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py

@@ -22,64 +22,33 @@ class MRotaryEmbedding:
 
     @staticmethod
     def get_input_positions(
-        input_tokens: List[int],
+        input_tokens: torch.Tensor,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
-        extend_prefix_len: int = 0,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
 
-        input_tokens_tensor = torch.tensor(input_tokens)
         vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
+            input_tokens == vision_start_token_id
         ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
         llm_pos_ids_list: list = []
 
         st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
             llm_grid_t, llm_grid_h, llm_grid_w = (
                 t,
                 h // spatial_merge_size,
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
             )
             st = ed + llm_grid_t * llm_grid_h * llm_grid_w
 
-        if st < len(input_tokens):
+        if st < input_tokens_len:
             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
+            text_len = input_tokens_len - st
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
             )
 
         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
         llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions += extend_prefix_len
-
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
         return llm_positions.tolist(), mrope_position_delta
 
     @staticmethod
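For readers of the hunk above: `get_input_positions` now takes the prompt as a `torch.Tensor` and treats every token following a `vision_start_token_id` as the start of an image block (the video branch is removed). A minimal sketch of a call, assuming the import path of the file above and using made-up token IDs and grid sizes (real IDs come from the model's tokenizer/processor config):

```python
import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

# Illustrative values only; actual vision-start / image-pad IDs are model specific.
VISION_START_ID = 151652
IMAGE_PAD_ID = 151655

# <vision_start> <img> <img> <img> <img> "Hello"
input_ids = torch.tensor(
    [VISION_START_ID, IMAGE_PAD_ID, IMAGE_PAD_ID, IMAGE_PAD_ID, IMAGE_PAD_ID, 9906]
)
image_grid_thw = torch.tensor([[1, 4, 4]])  # one image: t=1, h=4, w=4 patches

# With spatial_merge_size=2 the image occupies 1 * (4 // 2) * (4 // 2) = 4 positions.
positions, mrope_delta = MRotaryEmbedding.get_input_positions(
    input_tokens=input_ids,  # now a torch.Tensor instead of List[int]
    image_grid_thw=image_grid_thw,
    vision_start_token_id=VISION_START_ID,
    spatial_merge_size=2,
    context_len=0,
)
# positions: 3 x seq_len nested list (t/h/w rope channels)
# mrope_delta: max(position) + 1 - len(input_tokens)
```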
sglang/srt/layers/sampler.py

@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Union
 
 import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
         top_p_renorm_prob,
     )
 
+
+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,56 +39,62 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits
-
-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
                 )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-                    filter_apply_order="joint",
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        if not torch.all(success):
-            logger.warning("Detected errors during sampling!")
-            batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
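When `sampling_backend` is `pytorch`, the new branch calls `top_k_top_p_min_p_sampling_from_probs_torch`, whose definition starts in the trailing context above but whose body is not part of this hunk. As a rough, hedged sketch of what such a torch-native fallback typically does, not the package's actual implementation: filter the sorted distribution by min-p, top-p, and top-k, renormalize, then sample.

```python
import torch


def top_k_top_p_min_p_sampling_sketch(
    probs: torch.Tensor,   # [batch, vocab], each row sums to 1
    top_ks: torch.Tensor,  # [batch] integer top-k per request
    top_ps: torch.Tensor,  # [batch] top-p per request
    min_ps: torch.Tensor,  # [batch] min-p per request
) -> torch.Tensor:
    # Sort each row's probabilities in descending order.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # min-p: drop tokens whose probability is below min_p * max_prob.
    min_p_thresholds = probs_sort[:, 0:1] * min_ps.view(-1, 1)
    probs_sort[probs_sort < min_p_thresholds] = 0.0
    # top-p: zero the tail once cumulative mass (excluding self) exceeds top_p.
    cumsum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(cumsum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # top-k: keep only the top_k highest-probability tokens per row.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # Renormalize the surviving mass and sample one token per row.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    # Map positions in the sorted order back to vocabulary ids.
    return torch.gather(probs_idx, dim=-1, index=sampled).view(-1)
```

Filtering the sorted copy and mapping back through the sort indices keeps the original `probs` tensor intact and guarantees at least the argmax token always survives the three filters.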