sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/deepseekvl2.py

@@ -11,8 +11,6 @@ from transformers import (
  ProcessorMixin,
  )

- from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS
-

  def select_best_resolution(image_size, candidate_resolutions):
  # used for cropping
@@ -63,7 +61,6 @@ class DictOutput(object):
  class VLChatProcessorOutput(DictOutput):
  input_ids: torch.LongTensor
  target_ids: torch.LongTensor
- images_crop: torch.LongTensor
  pixel_values: (
  torch.Tensor
  ) # rename from "images" to "pixel_values" for compatibility
@@ -107,68 +104,6 @@ class ImageTransform(object):
  return x


- def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
- best_ratio_diff = float("inf")
- best_ratio = (1, 1)
- area = width * height
- for ratio in target_ratios:
- target_aspect_ratio = ratio[0] / ratio[1]
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
- if ratio_diff < best_ratio_diff:
- best_ratio_diff = ratio_diff
- best_ratio = ratio
- elif ratio_diff == best_ratio_diff:
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
- best_ratio = ratio
- return best_ratio
-
-
- def dynamic_preprocess(
- image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
- ):
- orig_width, orig_height = image.size
- aspect_ratio = orig_width / orig_height
-
- # calculate the existing image aspect ratio
- target_ratios = set(
- (i, j)
- for n in range(min_num, max_num + 1)
- for i in range(1, n + 1)
- for j in range(1, n + 1)
- if i * j <= max_num and i * j >= min_num
- )
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- # find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio(
- aspect_ratio, target_ratios, orig_width, orig_height, image_size
- )
-
- # calculate the target width and height
- target_width = image_size * target_aspect_ratio[0]
- target_height = image_size * target_aspect_ratio[1]
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
- # resize the image
- resized_img = image.resize((target_width, target_height))
- processed_images = []
- for i in range(blocks):
- box = (
- (i % (target_width // image_size)) * image_size,
- (i // (target_width // image_size)) * image_size,
- ((i % (target_width // image_size)) + 1) * image_size,
- ((i // (target_width // image_size)) + 1) * image_size,
- )
- # split the image
- split_img = resized_img.crop(box)
- processed_images.append(split_img)
- assert len(processed_images) == blocks
- if use_thumbnail and len(processed_images) != 1:
- thumbnail_img = image.resize((image_size, image_size))
- processed_images.append(thumbnail_img)
- return processed_images, target_aspect_ratio
-
-
  class DeepseekVLV2Processor(ProcessorMixin):
  tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
  attributes = ["tokenizer"]
@@ -198,7 +133,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
  self.image_std = image_std
  self.normalize = normalize
  self.downsample_ratio = downsample_ratio
- self.base_size = BASE_SIZE
+
  self.image_transform = ImageTransform(
  mean=image_mean, std=image_std, normalize=normalize
  )
@@ -241,7 +176,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
  **kwargs,
  )

- def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
+ def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
  """play the role of format_messages_v2 and get_images_info in the last version"""
  tokenized_data = []
  masked_tokenized_data = [] # labels
@@ -251,34 +186,35 @@ class DeepseekVLV2Processor(ProcessorMixin):

  image_index = 0
  image_token_cnt = messages.count(self.image_token)
- (
- input_ids,
- images,
- images_crop,
- seq_mask,
- spatial_crop,
- num_image_tokens,
- image_shapes,
- ) = self.tokenize_with_images(
+ tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
  messages,
  pil_images[image_index : image_index + image_token_cnt],
  bos=True,
  eos=True,
  cropping=len(pil_images) <= 2,
+ max_req_input_len=max_req_input_len,
  )

  image_index = image_token_cnt
+ tokenized_data += tokenized_str
+ if self.mask_prompt:
+ masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
+ else:
+ masked_tokenized_data += tokenized_str
  images_list += images
  images_seq_mask += seq_mask
- images_spatial_crop = spatial_crop
+ images_spatial_crop += spatial_crop
+
+ assert len(tokenized_data) == len(
+ images_seq_mask
+ ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

  return (
- input_ids,
+ tokenized_data,
  masked_tokenized_data,
  images_list,
  images_seq_mask,
  images_spatial_crop,
- images_crop,
  )

  @property
@@ -315,7 +251,6 @@ class DeepseekVLV2Processor(ProcessorMixin):
  inference_mode: bool = True,
  system_prompt: str = "",
  max_req_input_len: int = -1,
- cropping: bool = True,
  **kwargs,
  ):
  """
@@ -339,22 +274,47 @@ class DeepseekVLV2Processor(ProcessorMixin):
  - num_image_tokens (List[int]): the number of image tokens
  """

- prompt = conversations or prompt
+ assert (
+ prompt is None or conversations is None
+ ), "prompt and conversations cannot be used at the same time."
+
  (
- input_ids,
+ tokenized_str,
  masked_tokenized_str,
  images_list,
  images_seq_mask,
  images_spatial_crop,
- images_crop,
- ) = self.format_messages_v2(prompt, images, max_req_input_len)
+ ) = self.format_messages_v2(conversations, images, max_req_input_len)

+ assert (
+ len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+ ), (
+ f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+ f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+ )
+
+ input_ids = torch.LongTensor(tokenized_str)
  target_ids = torch.LongTensor(masked_tokenized_str)
+ images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+ # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+ target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+ self.ignore_id
+ )
+ input_ids[input_ids < 0] = self.pad_id
+
+ if inference_mode:
+ assert input_ids[-1] == self.eos_id
+ input_ids = input_ids[:-1]
+ target_ids = target_ids[:-1]
+ images_seq_mask = images_seq_mask[:-1]

  if len(images_list) == 0:
  images = torch.zeros((1, 3, self.image_size, self.image_size))
+ images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
  else:
  images = torch.stack(images_list, dim=0)
+ images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)

  images_spatial_crop = torch.stack(
  [images_spatial_crop], dim=0
@@ -363,7 +323,6 @@ class DeepseekVLV2Processor(ProcessorMixin):
  prepare = VLChatProcessorOutput(
  input_ids=input_ids,
  target_ids=target_ids,
- images_crop=images_crop,
  pixel_values=images,
  images_seq_mask=images_seq_mask,
  images_spatial_crop=images_spatial_crop,
@@ -381,14 +340,10 @@ class DeepseekVLV2Processor(ProcessorMixin):
  inference_mode: bool = True,
  system_prompt: str = "",
  max_req_input_len: int = -1,
- text: list[str] = None,
  **kwargs,
  ):
- assert text is None or isinstance(text, list)
- if text is not None:
- text = text[0]
  prepare = self.process_one(
- prompt=prompt or text,
+ prompt=prompt,
  conversations=conversations,
  images=images,
  apply_sft_format=apply_sft_format,
@@ -413,83 +368,85 @@ class DeepseekVLV2Processor(ProcessorMixin):
  bos: bool = True,
  eos: bool = True,
  cropping: bool = True,
+ max_req_input_len: int = -1,
  ):
  """Tokenize text with <image> tags."""
-
- conversation = conversation
- assert conversation.count(self.image_token) == len(images)
+ images_list, images_seq_mask, images_spatial_crop = [], [], []
  text_splits = conversation.split(self.image_token)
- images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
- [],
- [],
- [],
- [],
- )
- image_shapes = []
- num_image_tokens = []
  tokenized_str = []
  for text_sep, image in zip(text_splits, images):
  """encode text_sep"""
  tokenized_sep = self.encode(text_sep, bos=False, eos=False)
-
  tokenized_str += tokenized_sep
  images_seq_mask += [False] * len(tokenized_sep)

- image_shapes.append(image.size)
-
- if image.size[0] <= 640 and image.size[1] <= 640:
- crop_ratio = [1, 1]
+ """select best resolution for anyres"""
+ if cropping:
+ best_width, best_height = select_best_resolution(
+ image.size, self.candidate_resolutions
+ )
  else:
- if cropping:
- images_crop_raw, crop_ratio = dynamic_preprocess(
- image, image_size=IMAGE_SIZE
- )
- else:
- crop_ratio = [1, 1]
+ best_width, best_height = self.image_size, self.image_size
+ # print(image.size, (best_width, best_height)) # check the select_best_resolutions func

  """process the global view"""
- if self.image_size <= 640 and not cropping:
- image = image.resize((self.image_size, self.image_size))
-
  global_view = ImageOps.pad(
  image,
- (self.base_size, self.base_size),
+ (self.image_size, self.image_size),
  color=tuple(int(x * 255) for x in self.image_transform.mean),
  )
  images_list.append(self.image_transform(global_view))

- num_width_tiles, num_height_tiles = crop_ratio
- images_spatial_crop.append([num_width_tiles, num_height_tiles])
+ """process the local views"""
+ local_view = ImageOps.pad(
+ image,
+ (best_width, best_height),
+ color=tuple(int(x * 255) for x in self.image_transform.mean),
+ )
+ for i in range(0, best_height, self.image_size):
+ for j in range(0, best_width, self.image_size):
+ images_list.append(
+ self.image_transform(
+ local_view.crop(
+ (j, i, j + self.image_size, i + self.image_size)
+ )
+ )
+ )

- if num_width_tiles > 1 or num_height_tiles > 1:
- for i in range(len(images_crop_raw)):
- images_crop_list.append(self.image_transform(images_crop_raw[i]))
+ """record height / width crop num"""
+ num_width_tiles, num_height_tiles = (
+ best_width // self.image_size,
+ best_height // self.image_size,
+ )
+ images_spatial_crop.append([num_width_tiles, num_height_tiles])

  """add image tokens"""
- num_queries = math.ceil(
+ h = w = math.ceil(
  (self.image_size // self.patch_size) / self.downsample_ratio
  )
- num_queries_base = math.ceil(
- (self.base_size // self.patch_size) / self.downsample_ratio
+ # global views tokens h * (w + 1), 1 is for line separator
+ tokenized_image = [self.image_token_id] * h * (w + 1)
+ # add a separator between global and local views
+ tokenized_image += [self.image_token_id]
+ # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
+ tokenized_image += (
+ [self.image_token_id]
+ * (num_height_tiles * h)
+ * (num_width_tiles * w + 1)
  )

- tokenized_image = (
- [self.image_token_id] * num_queries_base + [self.image_token_id]
- ) * num_queries_base
- tokenized_image += [self.image_token_id]
- if num_width_tiles > 1 or num_height_tiles > 1:
- tokenized_image += (
- [self.image_token_id] * (num_queries * num_width_tiles)
- + [self.image_token_id]
- ) * (num_queries * num_height_tiles)
  tokenized_str += tokenized_image
-
  images_seq_mask += [True] * len(tokenized_image)
- num_image_tokens.append(len(tokenized_image))
+ # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens

  """process the last text split"""
  tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
-
+ # deal with video, limit with request len
+ if max_req_input_len > -1:
+ if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
+ rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
+ tokenized_str = tokenized_str[:rest]
+ images_seq_mask = images_seq_mask[:rest]
  tokenized_str += tokenized_sep
  images_seq_mask += [False] * len(tokenized_sep)

@@ -505,64 +462,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
  images_seq_mask
  ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

- masked_tokenized_str = []
- for token_index in tokenized_str:
- if token_index != self.image_token_id:
- masked_tokenized_str.append(token_index)
- else:
- masked_tokenized_str.append(self.ignore_id)
-
- assert (
- len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
- ), (
- f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
- f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
- )
- input_ids = torch.LongTensor(tokenized_str)
- target_ids = torch.LongTensor(masked_tokenized_str)
- images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
-
- # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
- target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
- self.ignore_id
- )
- input_ids[input_ids < 0] = self.pad_id
-
- inference_mode = True
-
- if inference_mode:
- # Remove the ending eos token
- assert input_ids[-1] == self.eos_id
- input_ids = input_ids[:-1]
- target_ids = target_ids[:-1]
- images_seq_mask = images_seq_mask[:-1]
-
- if len(images_list) == 0:
- pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
- images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
- images_crop = torch.zeros(
- (1, 3, self.image_size, self.image_size)
- ).unsqueeze(0)
- else:
- pixel_values = torch.stack(images_list, dim=0)
- images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
- if images_crop_list:
- images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
- else:
- images_crop = torch.zeros(
- (1, 3, self.image_size, self.image_size)
- ).unsqueeze(0)
-
- input_ids = input_ids.unsqueeze(0)
- return (
- input_ids,
- pixel_values,
- images_crop,
- images_seq_mask,
- images_spatial_crop,
- num_image_tokens,
- image_shapes,
- )
+ return tokenized_str, images_list, images_seq_mask, images_spatial_crop


  class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
@@ -647,6 +547,7 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):


  class DeepseekV2Config(PretrainedConfig):
+
  model_type = "deepseek_v2"
  keys_to_ignore_at_inference = ["past_key_values"]

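Note on the deepseekvl2.py hunks above: post2 drops the OCR-style dynamic_preprocess cropping path and restores the original anyres tiling in tokenize_with_images, so the number of image placeholder tokens per image again follows the global/local layout spelled out in the restored code. A minimal sketch of that arithmetic in Python; the concrete sizes below are illustrative assumptions, not values from the shipped processor config:

import math

def image_token_count(best_width, best_height, image_size, patch_size, downsample_ratio):
    # Per-tile query grid, mirroring the restored tokenize_with_images().
    h = w = math.ceil((image_size // patch_size) / downsample_ratio)
    num_width_tiles = best_width // image_size
    num_height_tiles = best_height // image_size
    global_tokens = h * (w + 1)  # global view: +1 token per row as a line separator
    separator = 1                # one separator between global and local views
    local_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1)
    return global_tokens + separator + local_tokens

# Illustrative numbers only (not the shipped DeepSeek-VL2 settings):
print(image_token_count(best_width=768, best_height=384,
                        image_size=384, patch_size=16, downsample_ratio=2))  # 457
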
sglang/srt/configs/kimi_linear.py (new file)

@@ -0,0 +1,160 @@
+ # Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/transformers_utils/configs/kimi_linear.py
+ from transformers.configuration_utils import PretrainedConfig
+
+ from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+
+ class KimiLinearConfig(PretrainedConfig):
+ model_type = "kimi_linear"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ model_type="kimi_linear",
+ vocab_size=163840,
+ hidden_size=4096,
+ head_dim=None,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ tie_word_embeddings=False,
+ moe_intermediate_size: int | None = None,
+ moe_renormalize: bool = True,
+ moe_router_activation_func: str = "sigmoid",
+ num_experts: int | None = None,
+ num_experts_per_token: int | None = None,
+ num_shared_experts: int = 0,
+ routed_scaling_factor: float = 1.0,
+ first_k_dense_replace: int = 0,
+ moe_layer_freq: int = 1,
+ use_grouped_topk: bool = True,
+ num_expert_group: int = 1,
+ topk_group: int = 1,
+ q_lora_rank: int | None = None,
+ kv_lora_rank: int | None = None,
+ qk_nope_head_dim: int | None = None,
+ qk_rope_head_dim: int | None = None,
+ v_head_dim: int | None = None,
+ mla_use_nope: bool | None = False,
+ num_nextn_predict_layers: int = 0,
+ linear_attn_config: dict | None = None,
+ **kwargs,
+ ):
+ self.model_type = model_type
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.head_dim = (
+ head_dim if head_dim is not None else hidden_size // num_attention_heads
+ )
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+
+ self.q_lora_rank = q_lora_rank
+ self.kv_lora_rank = kv_lora_rank
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.mla_use_nope = mla_use_nope
+ # moe config
+ self.n_routed_experts = self.num_experts = num_experts
+ self.num_experts_per_token = num_experts_per_token
+ self.moe_renormalize = moe_renormalize
+ self.num_shared_experts = num_shared_experts
+ self.routed_scaling_factor = routed_scaling_factor
+ self.moe_router_activation_func = moe_router_activation_func
+ assert self.moe_router_activation_func in ("softmax", "sigmoid")
+ self.moe_intermediate_size = moe_intermediate_size
+ self.first_k_dense_replace = first_k_dense_replace
+ self.moe_layer_freq = moe_layer_freq
+ self.use_grouped_topk = use_grouped_topk
+ self.num_expert_group = num_expert_group
+ self.topk_group = topk_group
+ self.num_nextn_predict_layers = num_nextn_predict_layers
+
+ if linear_attn_config is not None:
+ assert linear_attn_config["kda_layers"] is not None
+ assert linear_attn_config["full_attn_layers"] is not None
+ self.linear_attn_config = linear_attn_config
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ @property
+ def is_mla(self):
+ return (
+ self.q_lora_rank is not None
+ or self.kv_lora_rank is not None
+ or self.qk_nope_head_dim is not None
+ or self.qk_rope_head_dim is not None
+ or self.v_head_dim is not None
+ or self.mla_use_nope is True
+ )
+
+ @property
+ def is_moe(self):
+ return self.num_experts is not None
+
+ @property
+ def is_linear_attn(self) -> bool:
+ return not (
+ self.linear_attn_config is None
+ or (
+ isinstance(self.linear_attn_config, dict)
+ and self.linear_attn_config["kda_layers"] is not None
+ and len(self.linear_attn_config["kda_layers"]) == 0
+ )
+ )
+
+ def is_kda_layer(self, layer_idx: int):
+ return (
+ self.linear_attn_config is not None
+ and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
+ )
+
+ @property
+ def linear_layer_ids(self):
+ return [i for i in range(self.num_hidden_layers) if self.is_kda_layer(i)]
+
+ @property
+ def full_attention_layer_ids(self):
+ return [i for i in range(self.num_hidden_layers) if not self.is_kda_layer(i)]
+
+ @property
+ def mamba2_cache_params(self) -> KimiLinearCacheParams:
+ shape = KimiLinearStateShape.create(
+ tp_world_size=get_attention_tp_size(),
+ num_heads=self.linear_attn_config["num_heads"],
+ head_dim=self.linear_attn_config["head_dim"],
+ conv_kernel_size=self.linear_attn_config["short_conv_kernel_size"],
+ )
+
+ return KimiLinearCacheParams(shape=shape, layers=self.linear_layer_ids)
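
The new KimiLinearConfig partitions layers into KDA (linear-attention) and full-attention groups by membership in the 1-indexed kda_layers list of linear_attn_config, as implemented by is_kda_layer / linear_layer_ids above. A standalone restatement of that partition logic, with a made-up layer layout (real checkpoints ship their own linear_attn_config):

# Hypothetical 8-layer layout; kda_layers / full_attn_layers are 1-indexed in the config.
num_hidden_layers = 8
linear_attn_config = {"kda_layers": [1, 2, 3, 5, 6, 7], "full_attn_layers": [4, 8]}

def is_kda_layer(layer_idx: int) -> bool:
    # layer_idx is 0-indexed; the config list is 1-indexed
    return (layer_idx + 1) in linear_attn_config["kda_layers"]

linear_layer_ids = [i for i in range(num_hidden_layers) if is_kda_layer(i)]
full_attention_layer_ids = [i for i in range(num_hidden_layers) if not is_kda_layer(i)]

print(linear_layer_ids)          # [0, 1, 2, 4, 5, 6]
print(full_attention_layer_ids)  # [3, 7]
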
sglang/srt/configs/mamba_utils.py

@@ -14,6 +14,7 @@

  import os
  from dataclasses import dataclass, field
+ from typing import List, Optional

  import numpy as np
  import torch
@@ -115,3 +116,68 @@ class Mamba2CacheParams:
  int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
  + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
  ) * len(self.layers)
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class KimiLinearStateShape:
+ conv: List[tuple[int, int]]
+ temporal: tuple[int, int, int]
+
+ num_heads: int
+ head_dim: int
+ num_k_heads: int
+ head_k_dim: int
+ conv_kernel: int
+ num_spec: int
+
+ @staticmethod
+ def create(
+ *,
+ tp_world_size: int,
+ num_heads: int,
+ head_dim: int,
+ num_k_heads: Optional[int] = None,
+ head_k_dim: Optional[int] = None,
+ conv_kernel_size: int = 4,
+ num_spec: int = 0,
+ ) -> "KimiLinearStateShape":
+ if num_k_heads is None:
+ num_k_heads = num_heads
+ if head_k_dim is None:
+ head_k_dim = head_dim
+
+ proj_size = num_heads * head_dim
+ proj_k_size = num_k_heads * head_k_dim
+
+ conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
+ conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
+ temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
+
+ conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+ conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
+
+ return KimiLinearStateShape(
+ conv=[conv_state_shape, conv_state_k_shape, conv_state_k_shape],
+ temporal=temporal_state_shape,
+ num_heads=num_heads,
+ head_dim=head_dim,
+ num_k_heads=num_k_heads,
+ head_k_dim=head_k_dim,
+ conv_kernel=conv_kernel_size,
+ num_spec=num_spec,
+ )
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class KimiLinearCacheParams:
+ shape: KimiLinearStateShape
+ dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
+ layers: list[int]
+
+ @property
+ def mamba_cache_per_req(self) -> int:
+ return (
+ int(np.sum([np.prod(conv_shape) for conv_shape in self.shape.conv]))
+ * self.dtype.conv.itemsize
+ + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
+ ) * len(self.layers)
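
The companion mamba_utils.py addition sizes the per-request recurrent cache for Kimi Linear: three short-conv state tensors (stored transposed as (kernel_size - 1, channels)) plus one (heads, head_dim, head_dim) temporal state, summed in bytes and multiplied by the number of KDA layers. A standalone sketch of the mamba_cache_per_req arithmetic with assumed sizes and a 2-byte state dtype; the real dtypes come from mamba2_state_dtype(), and head counts from the checkpoint's linear_attn_config:

import numpy as np

# Assumed example sizes, not taken from any released Kimi Linear checkpoint.
tp_world_size, num_heads, head_dim, conv_kernel = 1, 32, 128, 4
num_linear_layers = 24
itemsize = 2  # assume 2-byte (bf16/fp16) states for illustration

proj = num_heads * head_dim // tp_world_size
conv_shapes = [(conv_kernel - 1, proj)] * 3  # one q-proj and two k-proj conv states (equal with default dims)
temporal_shape = (num_heads // tp_world_size, head_dim, head_dim)

per_layer = (sum(int(np.prod(s)) for s in conv_shapes) * itemsize
             + int(np.prod(temporal_shape)) * itemsize)
print(per_layer * num_linear_layers, "bytes per request")  # 26935296
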