sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(
-                        operation.host_indices
-                    )
-                    self.mem_pool_device.transfer(
-                        operation.device_indices, operation.data
-                    )
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
                 if batch_operation is None:
-                    continue
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
-
-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
+
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
-
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
-
-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
+                if factor >= 1:
+                    if buffer is not None:
                         _to_op(buffer)
                         buffer = None
-                except Empty:
+
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
                     continue
-                except Exception as e:
-                    logger.error(e)
+
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()
 
     def evict_device(
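
Note: per the file list above, the three hunks in this block appear to come from sglang/srt/managers/cache_controller.py. The substantive change is that each worker thread now calls torch.cuda.set_stream() once at start-up instead of wrapping its polling loop in a torch.cuda.stream() context manager. The snippet below is a minimal, hypothetical illustration of that pattern, not sglang code, and it assumes a CUDA-capable PyTorch install:

import threading

import torch


def worker(stream):
    # torch.cuda.set_stream() switches the calling thread's current stream for
    # the rest of its lifetime, so a long-running polling loop no longer needs
    # to sit inside `with torch.cuda.stream(stream):`, which only applies while
    # the block is active and restores the previous stream on exit.
    torch.cuda.set_stream(stream)
    x = torch.ones(1024, device="cuda")
    x = x * 2  # issued on `stream`
    stream.synchronize()


if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    t = threading.Thread(target=worker, args=(side_stream,), daemon=True)
    t.start()
    t.join()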
@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )
 
-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism resues the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
 
         # Wait for model to finish loading
        scheduler_info = []
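
Note: this hunk (sglang/srt/managers/data_parallel_controller.py in the file list) nests the tensor-parallel launch loop inside a new pipeline-parallel loop and derives the local GPU index from both ranks. The helper below is a standalone sketch that mirrors the same arithmetic; local_ranks is a hypothetical name, not an sglang API, and the two base-GPU offsets of the real code are collapsed into a single base_gpu_id argument:

def local_ranks(nnodes, node_rank, tp_size, pp_size, base_gpu_id=0, gpu_id_step=1):
    # Mirrors the rank/GPU arithmetic in the hunk above.
    nnodes_per_tp_group = max(nnodes // pp_size, 1)
    tp_size_per_node = tp_size // nnodes_per_tp_group
    pp_size_per_node = max(pp_size // nnodes, 1)
    tp_ranks = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_ranks = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )
    for pp_rank in pp_ranks:
        for tp_rank in tp_ranks:
            gpu_id = (
                base_gpu_id
                + (pp_rank % pp_size_per_node) * tp_size_per_node
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            yield pp_rank, tp_rank, gpu_id


# Example: 2 nodes, tp_size=8, pp_size=2. Node 0 hosts pipeline stage 0 on
# GPUs 0-7 and node 1 hosts stage 1 on GPUs 0-7.
print(list(local_ranks(nnodes=2, node_rank=1, tp_size=8, pp_size=2)))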
@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput:
     pass
 
 
+@dataclass
+class SlowDownReqInput:
+    forward_sleep_time: Optional[float]
+
+
+@dataclass
+class SlowDownReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
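
Note: the pixel_values hunk above only moves the tensor case to host memory; if the underlying HuggingFace processor returns pixel_values as something other than a torch.Tensor (a list, for instance), the isinstance check leaves it untouched. A minimal, self-contained sketch of that guard with made-up data:

import torch

for result in (
    {"pixel_values": torch.rand(1, 3, 224, 224)},  # tensor: moved to CPU
    {"pixel_values": [[0.0, 1.0], [1.0, 0.0]]},  # not a tensor: left as-is
):
    if "pixel_values" in result and isinstance(result["pixel_values"], torch.Tensor):
        result["pixel_values"] = result["pixel_values"].to("cpu")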
@@ -0,0 +1,232 @@
+# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
+
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from numpy.distutils.cpuinfo import cpu
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.internvl import InternVLChatModel
+
+
+class InternVLImageProcessor(BaseMultimodalProcessor):
+    models = [InternVLChatModel]
+
+    def __init__(self, hf_config, server_args, _image_processor):
+        super().__init__(hf_config, server_args, _image_processor)
+        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+        patch_size = hf_config.vision_config.patch_size
+
+        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+        self.IMG_START_TOKEN = "<img>"
+        self.IMG_END_TOKEN = "</img>"
+        self.IMG_TOKEN = "<image>"
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
+        )
+
+        tokenizer = self._processor
+        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
+        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
+        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
+            self.IMG_CONTEXT_TOKEN
+        )
+
+    @staticmethod
+    def build_transform(input_size):
+        IMAGENET_MEAN = (0.485, 0.456, 0.406)
+        IMAGENET_STD = (0.229, 0.224, 0.225)
+
+        def resize_image(img, size):
+            return img.resize((size, size), Image.Resampling.BICUBIC)
+
+        def to_tensor(img):
+            # Convert PIL Image to numpy array
+            img_array = np.array(img).astype(np.float32) / 255.0
+            # Convert HWC to CHW format
+            img_array = img_array.transpose(2, 0, 1)
+            return torch.from_numpy(img_array)
+
+        def normalize(tensor, mean, std):
+            mean = torch.tensor(mean).view(-1, 1, 1)
+            std = torch.tensor(std).view(-1, 1, 1)
+            return (tensor - mean) / std
+
+        def transform(img):
+            img = img.convert("RGB") if img.mode != "RGB" else img
+            img = resize_image(img, input_size)
+            tensor = to_tensor(img)
+            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
+            return tensor
+
+        return transform
+
+    @staticmethod
+    def dynamic_preprocess(
+        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+    ):
+
+        def find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, width, height, image_size
+        ):
+            best_ratio_diff = float("inf")
+            best_ratio = (1, 1)
+            area = width * height
+            for ratio in target_ratios:
+                target_aspect_ratio = ratio[0] / ratio[1]
+                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+                if ratio_diff < best_ratio_diff:
+                    best_ratio_diff = ratio_diff
+                    best_ratio = ratio
+                elif ratio_diff == best_ratio_diff:
+                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                        best_ratio = ratio
+            return best_ratio
+
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j)
+            for n in range(min_num, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num and i * j >= min_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size
+        )
+
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size,
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+    @staticmethod
+    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+        if bound:
+            start, end = bound[0], bound[1]
+        else:
+            start, end = -100000, 100000
+        start_idx = max(first_idx, round(start * fps))
+        end_idx = min(round(end * fps), max_frame)
+        seg_size = float(end_idx - start_idx) / num_segments
+        frame_indices = np.array(
+            [
+                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+                for idx in range(num_segments)
+            ]
+        )
+        return frame_indices
+
+    @staticmethod
+    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        max_frame = len(vr) - 1
+        fps = float(vr.get_avg_fps())
+
+        pixel_values_list, num_patches_list = [], []
+        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        frame_indices = InternVLImageProcessor.get_index(
+            bound, fps, max_frame, first_idx=0, num_segments=num_segments
+        )
+        for frame_index in frame_indices:
+            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+            img = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(tile) for tile in img]
+            pixel_values = torch.stack(pixel_values)
+            num_patches_list.append(pixel_values.shape[0])
+            pixel_values_list.append(pixel_values)
+        pixel_values = torch.cat(pixel_values_list)
+        return pixel_values, num_patches_list
+
+    async def process_mm_data_async(
+        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
+            max_req_input_len=max_req_input_len,
+            discard_alpha_channel=True,
+        )
+
+        def process_image_internvl(image, input_size=448, max_num=12):
+            transform = InternVLImageProcessor.build_transform(input_size=input_size)
+            images = InternVLImageProcessor.dynamic_preprocess(
+                image, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(image) for image in images]
+            pixel_values = torch.stack(pixel_values)
+            return pixel_values
+
+        num_patches_list = []
+        pixel_values = []
+        # Process each input with allocated frames
+        for image_index, (image) in enumerate(base_output.images):
+            try:
+                # TODO: video input
+                raw_image = process_image_internvl(image)
+                pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+                pixel_values += pixel_value
+                num_patches = raw_image.shape[0]
+                num_patches_list += [num_patches]
+
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+        pixel_values = torch.cat(pixel_values, dim=0)
+        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
+
+        for idx, num_patches in enumerate(num_patches_list):
+            image_tokens = (
+                self.IMG_START_TOKEN
+                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                + self.IMG_END_TOKEN
+            )
+            input_text = input_text.replace("<image>", image_tokens, 1)
+
+        tokenizer = self._processor
+        return {
+            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
+            .flatten()
+            .tolist(),
+            "mm_items": items,
+            "im_start_id": self.img_start_token_id,
+            "im_end_id": self.img_end_token_id,
+            "im_token_id": self.img_context_token_id,
+        }
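
Note: in the new InternVL processor above, process_mm_data_async expands each <image> placeholder into <img> ... </img> with num_image_token context tokens per tile produced by dynamic_preprocess. The sketch below reproduces only that token arithmetic in standalone form; the configuration values (image_size=448, patch_size=14, downsample_ratio=0.5) are typical InternVL settings assumed for illustration, not read from the package:

# Assumed InternVL-style configuration, mirroring num_image_token in __init__ above.
image_size, patch_size, downsample_ratio = 448, 14, 0.5
num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio**2)  # 256

# e.g. a 2x1 tiling from dynamic_preprocess plus the thumbnail tile
num_patches = 3
image_tokens = "<img>" + "<IMG_CONTEXT>" * num_image_token * num_patches + "</img>"
prompt = "Describe this image: <image>".replace("<image>", image_tokens, 1)
print(num_image_token, num_patches, prompt.count("<IMG_CONTEXT>"))  # 256 3 768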