sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -22,21 +22,16 @@ class LoRAMemoryPool:
         self,
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
-        max_lora_dim: int,
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
-        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):
-
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
         self.max_loras_per_batch: int = max_loras_per_batch
-        self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -55,89 +50,95 @@ class LoRAMemoryPool:
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
-                input_dim = divide(input_dim, self.tp_size)
+        if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            input_dim = divide(input_dim, self.tp_size)
         return (
             self.max_loras_per_batch,
-            self.max_lora_dim * c,
+            max_lora_dim * c,
             input_dim,
         )
 
     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
    ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
-                output_dim = divide(output_dim, self.tp_size)
+        if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            output_dim = divide(output_dim, self.tp_size)
         return (
             c,
             self.max_loras_per_batch,
             output_dim,
-            self.max_lora_dim,
+            max_lora_dim,
         )
 
     def init_buffers(
         self,
         lora_weight_names: Tuple[Set[str]],
         base_model: torch.nn.Module,
+        max_lora_dim: int,
     ):
-
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
         device = next(base_model.parameters()).device
-        # Init A tensor, column_major=False
-        for module_A in lora_weight_names[0]:
-            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
-            self.A_buffer[module_A] = [
-                torch.empty(
-                    lora_A_shape,
-                    dtype=self.dtype,
-                    device=device,
-                )
-                for _ in range(self.num_layer)
-            ]
-        # Init B tensor, column_major=True
-        for module_B in lora_weight_names[1]:
-            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
-            self.B_buffer[module_B] = [
-                torch.empty(
-                    lora_B_shape,
-                    dtype=self.dtype,
-                    device=device,
-                )
-                for _ in range(self.num_layer)
-            ]
+
+        def update_buffer(
+            buffer: Dict[str, List[torch.Tensor]],
+            lora_weight_names: Set[str],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+        ):
+            new_weight_names = lora_weight_names - buffer.keys()
+            for module_name in new_weight_names:
+                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+                buffer[module_name] = [
+                    torch.empty(
+                        lora_shape,
+                        dtype=self.dtype,
+                        device=device,
+                    )
+                    for _ in range(self.num_layer)
+                ]
+
+        update_buffer(
+            self.A_buffer,
+            lora_weight_names[0],
+            self.get_lora_A_shape,
+        )
+
+        update_buffer(
+            self.B_buffer,
+            lora_weight_names[1],
+            self.get_lora_B_shape,
+        )
 
     def prepare_lora_batch(
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id, ""
+                    return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id
 
             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
@@ -145,17 +146,20 @@ class LoRAMemoryPool:
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id, evicted_lora_uid = get_available_buffer_slot()
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
+                lora_adapter = lora_adapters.get(uid, None)
                 self.load_lora_weight_to_buffer(
-                    uid, buffer_id, lora_adapters.get(uid, None)
+                    uid, buffer_id, lora_adapter, lora_modules
                 )
                 self.uid_to_buffer_id[uid] = buffer_id
                 self.buffer_id_to_uid[buffer_id] = uid
 
     def load_lora_weight_to_buffer(
-        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
+        self,
+        uid: str,
+        buffer_id: int,
+        lora_adapter: LoRAAdapter,
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
         def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
             assert (
@@ -187,8 +191,8 @@ class LoRAMemoryPool:
                     temp_B_buffer[lora_weight_name] = weights
 
             if self.tp_size > 1:
-                cur_layer_modules = self.lora_modules[layer_id]
-                for module_name, module in cur_layer_modules:
+                cur_layer_modules = lora_modules[layer_id]
+                for module_name, module in cur_layer_modules.items():
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -237,7 +241,6 @@ class LoRAMemoryPool:
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
-
         if lora_type == LoRAType.LORA_A:
             return self.A_buffer[weight_name][layer_id]
 
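To make the new calling convention concrete, here is a minimal usage sketch assuming only the signatures shown above: max_lora_dim and lora_modules now travel with each call instead of being fixed at construction, and init_buffers only allocates tensors for weight names it has not seen before. The surrounding objects (base_hf_config, base_model, lora_weight_names, adapters, modules_by_layer) are placeholders, supplied in the package by sglang's LoRAManager rather than defined here.

import torch

from sglang.srt.lora.mem_pool import LoRAMemoryPool

# Sketch only: the placeholder objects must come from a real model setup.
pool = LoRAMemoryPool(
    base_hf_config=base_hf_config,   # transformers AutoConfig of the base model
    max_loras_per_batch=4,
    dtype=torch.bfloat16,
    tp_size=1,
    tp_rank=0,
)

# max_lora_dim is now an argument here, so buffer shapes can track the largest
# adapter rank known at call time; only names without existing buffers are allocated.
pool.init_buffers(lora_weight_names, base_model, max_lora_dim=16)

# lora_modules (layer_id -> {module_name: BaseLayerWithLoRA}) is passed per batch
# instead of being stored on the pool.
pool.prepare_lora_batch(
    cur_uids={"my-adapter"},
    lora_adapters=adapters,
    lora_modules=modules_by_layer,
)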
sglang/srt/lora/utils.py CHANGED
@@ -108,7 +108,7 @@ def get_hidden_dim(
 
 def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
     """
-    Mapping a target module name to names of the normized LoRA weights.
+    Mapping a target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
sglang/srt/managers/cache_controller.py CHANGED
@@ -18,33 +18,50 @@ import logging
 import math
 import threading
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+    from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 logger = logging.getLogger(__name__)
 
 
 class LayerDoneCounter:
     def __init__(self, num_layers):
-        self.counter = num_layers
-        self.condition = threading.Condition()
+        self.num_layers = num_layers
+        # extra producer and consumer counters for overlap mode
+        self.num_counters = 3
+        self.counters = [num_layers] * self.num_counters
+        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
+        self.producer_index = 0
+        self.consumer_index = 0
+
+    def next_producer(self):
+        return (self.producer_index + 1) % self.num_counters
+
+    def update_producer(self):
+        self.producer_index = self.next_producer()
+        return self.producer_index
+
+    def set_consumer(self, index):
+        self.consumer_index = index
 
     def increment(self):
-        with self.condition:
-            self.counter += 1
-            self.condition.notify_all()
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] += 1
+            self.conditions[self.producer_index].notify_all()
 
     def wait_until(self, threshold):
-        with self.condition:
-            while self.counter <= threshold:
-                self.condition.wait()
+        with self.conditions[self.consumer_index]:
+            while self.counters[self.consumer_index] <= threshold:
+                self.conditions[self.consumer_index].wait()
 
     def reset(self):
-        with self.condition:
-            self.counter = 0
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] = 0
 
 
 class CacheOperation:
@@ -147,7 +164,7 @@ class HiCacheController:
 
     def __init__(
         self,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
         load_cache_event: threading.Event = None,
@@ -295,7 +312,6 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.load_queue.get(block=True, timeout=1)
-                # time.sleep(18e-6 * len(operation.host_indices))
                 operation.data = self.mem_pool_host.get_flat_data(
                     operation.host_indices
                 )
@@ -319,6 +335,7 @@ class HiCacheController:
             if not self.load_cache_event.is_set():
                 continue
             self.load_cache_event.clear()
+            self.layer_done_counter.update_producer()
 
             batch_operation = None
             while self.load_queue.qsize() > 0:
@@ -330,6 +347,7 @@ class HiCacheController:
             if batch_operation is None:
                 continue
 
+            # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
                 if self.page_size == 1:
@@ -465,6 +483,7 @@ class HiCacheController:
         except Exception as e:
            logger.error(e)
 
+    # todo (zhiqiang): double buffering to be deprecated
    def write_thread_func_buffer(self):
        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
        aux_thread.start()
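The reworked LayerDoneCounter keeps several counter/condition pairs so that, in overlap mode, a new batch of layer-wise transfers can be produced while a consumer is still waiting on the previous one. Below is a minimal sketch of the handshake using only the methods shown in the diff; the producer/consumer call sites sketched here are an assumption, not code from the package.

from sglang.srt.managers.cache_controller import LayerDoneCounter

counter = LayerDoneCounter(num_layers=32)

# Producer side (the KV cache load thread): rotate to a fresh counter for the
# next batch of host-to-device transfers, then report per-layer completion.
producer_index = counter.update_producer()
counter.reset()
for layer_id in range(32):
    # ... finish copying this layer's KV cache to the device ...
    counter.increment()              # wakes consumers waiting on this counter

# Consumer side (e.g. the forward pass on another thread): bind to the
# producer's counter, then block until the layer it needs has been loaded.
counter.set_consumer(producer_index)
counter.wait_until(5)                # returns once counters[index] exceeds 5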
sglang/srt/managers/io_struct.py CHANGED
@@ -87,7 +87,7 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # LoRA related
+    # The path to the LoRA
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
@@ -226,11 +226,11 @@ class GenerateReqInput:
 
         # Expand input based on type
         self._expand_inputs(num)
+        self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
-        self._normalize_rid(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)
 
@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
            ),
-            return_hidden_states=self.return_hidden_states,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -477,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -501,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -524,6 +530,7 @@ class EmbeddingReqInput:
         if self.text is not None:
             if isinstance(self.text, list):
                 self.batch_size += len(self.text)
+                self.is_single = False
             else:
                 self.batch_size += 1
 
@@ -531,12 +538,10 @@ class EmbeddingReqInput:
         if self.input_ids is not None:
             if isinstance(self.input_ids[0], list):
                 self.batch_size += len(self.input_ids)
+                self.is_single = False
             else:
                 self.batch_size += 1
 
-        if self.batch_size > 1:
-            self.is_single = False
-
         # Fill in default arguments
         if self.is_single:
             if self.rid is None:
@@ -560,6 +565,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -579,6 +594,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
@@ -794,7 +811,9 @@ class GetWeightsByNameReqOutput:
 
 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None
 
 
 @dataclass
@@ -804,7 +823,9 @@ class ReleaseMemoryOccupationReqOutput:
 
 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None
 
 
 @dataclass
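Two of the io_struct changes above are user-visible request options: return_hidden_states may now be a per-request list for batched inputs, and the release/resume memory-occupation requests accept tags. A brief sketch of how these fields might be filled in (the concrete values are illustrative, not from the package):

from sglang.srt.managers.io_struct import (
    GenerateReqInput,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
)

# Hidden states can now be requested per prompt within one batch;
# a single bool is still accepted and applies to every prompt.
req = GenerateReqInput(
    text=["What is the capital of France?", "Write a haiku."],
    return_hidden_states=[True, False],
)

# RL-style memory release/resume scoped by tag; per the comments above,
# only "weights" and "kv_cache" are currently supported.
release = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])
resume = ResumeMemoryOccupationReqInput(tags=["kv_cache"])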
sglang/srt/managers/multimodal_processors/base_processor.py CHANGED
@@ -146,7 +146,7 @@ class BaseMultimodalProcessor(ABC):
         request_obj,
         max_req_input_len,
         **kwargs,
-    ):
+    ) -> Optional[Dict[str, Any]]:
         pass
 
     def get_estimated_frames_list(self, image_data):
@@ -261,7 +261,7 @@ class BaseMultimodalProcessor(ABC):
 
     def load_mm_data(
         self,
-        prompt: str,
+        prompt: str | List[int],
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
sglang/srt/managers/multimodal_processors/vila.py ADDED
@@ -0,0 +1,85 @@
+from typing import Any, Dict, List, Optional, Type, cast
+
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    ImageDataItem,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.vila import VILAForConditionalGeneration
+from sglang.srt.server_args import ServerArgs
+
+
+class VILAProcessor(ProcessorMixin):
+    """A stub class for the VILA processor."""
+
+    tokenizer: PreTrainedTokenizerBase
+
+
+class VILAMultimodalProcessor(BaseMultimodalProcessor):
+    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]
+
+    _processor: VILAProcessor
+
+    def __init__(
+        self,
+        hf_config: PretrainedConfig,
+        server_args: ServerArgs,
+        _processor: VILAProcessor,
+    ) -> None:
+        super().__init__(hf_config, server_args, _processor)
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[ImageDataItem | List[ImageDataItem]],
+        input_text: str | List[int],
+        request_obj: GenerateReqInput | EmbeddingReqInput,
+        max_req_input_len: int,
+        **kwargs,
+    ) -> Optional[Dict[str, Any]]:
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self._processor.tokenizer.image_token
+            ),
+            max_req_input_len=max_req_input_len,
+            image_data=image_data,
+        )
+
+        inputs = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        image_offsets = self.get_mm_items_offset(
+            input_ids=inputs.input_ids[0],
+            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
+        )
+
+        mm_items: List[MultimodalDataItem] = [
+            MultimodalDataItem(
+                modality=Modality.IMAGE,
+                image_offsets=image_offsets,
+                pixel_values=inputs.pixel_values,
+            )
+        ]
+
+        return dict(
+            input_ids=inputs.input_ids[0].tolist(),
+            mm_items=mm_items,
+        )
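For context, a rough sketch of what this new processor's entry point returns when driven directly. In the server the class is constructed and awaited by sglang's multimodal processing pipeline, so hf_config, server_args, hf_processor, and req below are stand-ins, not code from the package.

import asyncio

from sglang.srt.managers.multimodal_processors.vila import VILAMultimodalProcessor

# Hypothetical driver; real values come from the loaded VILA checkpoint.
processor = VILAMultimodalProcessor(hf_config, server_args, hf_processor)

result = asyncio.run(
    processor.process_mm_data_async(
        image_data=["/path/to/image.png"],
        input_text="<image> Describe this picture.",  # image token string is tokenizer-specific
        request_obj=req,  # a GenerateReqInput or EmbeddingReqInput
        max_req_input_len=4096,
    )
)
# result is None when no image is supplied; otherwise a dict with "input_ids"
# (the tokenized prompt) and "mm_items" (one MultimodalDataItem carrying
# pixel_values and the image token offsets).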