sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(
-                        operation.host_indices
-                    )
-                    self.mem_pool_device.transfer(
-                        operation.device_indices, operation.data
-                    )
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
-                if batch_operation is None:
-                    continue
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
+                if batch_operation is None:
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
-
-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
+
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
-
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
-
-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
-                        _to_op(buffer)
-                        buffer = None
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+                if factor >= 1:
+                    if buffer is not None:
+                        _to_op(buffer)
+                        buffer = None
+
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
+                    continue
+
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
        aux_thread.join()

    def evict_device(
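
Note on the recurring change in these hunks: PyTorch's current CUDA stream is thread-local, and the worker loops above never leave the `with torch.cuda.stream(...)` context, so setting the thread's stream once with `torch.cuda.set_stream` is equivalent while dropping a level of indentation and the context unwind on exit. A minimal sketch of the pattern, with a hypothetical worker and only stock torch APIs:

    import threading

    import torch

    def worker(stream: torch.cuda.Stream, stop: threading.Event) -> None:
        # set_stream changes this thread's current stream for all
        # subsequent CUDA work; no context manager to re-enter per loop.
        torch.cuda.set_stream(stream)
        while not stop.is_set():
            ...  # enqueue copies/kernels; they run on `stream`
            stream.synchronize()  # wait only for this thread's stream

    stop = threading.Event()
    t = threading.Thread(target=worker, args=(torch.cuda.Stream(), stop), daemon=True)
    t.start()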
sglang/srt/managers/data_parallel_controller.py
@@ -17,13 +17,13 @@ import logging
 import multiprocessing as mp
 import signal
 import threading
+import time
 from enum import Enum, auto
 
 import psutil
 import setproctitle
 import zmq
 
-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
@@ -158,7 +158,7 @@ class DataParallelController:
         # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
         # function in scheduler.py will kill the scheduler.
         while True:
-            pass
+            time.sleep(30 * 24 * 3600)
 
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
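
The old `while True: pass` spin kept the thread alive but pinned a CPU core; sleeping in 30-day slices keeps the thread (and the parent-death watchdog it protects) alive at essentially zero cost. An alternative that blocks without any polling at all, shown here as an assumption rather than what the code does:

    import threading

    def keep_alive_forever() -> None:
        # wait() on an Event that nobody ever sets blocks indefinitely;
        # unlike `while True: pass`, it consumes no CPU while waiting.
        threading.Event().wait()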
@@ -210,7 +210,7 @@ class DataParallelController:
         )
         # compute zmq ports for this dp rank
         rank_port_args = PortArgs.init_new(server_args, dp_rank)
-        # Data parallelism resues the tensor parallelism group,
+        # Data parallelism reuses the tensor parallelism group,
         # so all dp ranks should use the same nccl port.
         rank_port_args.nccl_port = port_args.nccl_port
 
sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalDecodeReq,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
 )
@@ -60,6 +61,8 @@ class DecodeStatus:
     decode_ids: List[int]
     surr_offset: int
     read_offset: int
+    # Offset that's sent to tokenizer for incremental update.
+    sent_offset: int = 0
 
 
 class DetokenizerManager:
@@ -151,7 +154,7 @@ class DetokenizerManager:
             self.decode_status[rid] = s
         else:
             s = self.decode_status[rid]
-            s.decode_ids = recv_obj.decode_ids[i]
+            s.decode_ids.extend(recv_obj.decode_ids[i])
 
         read_ids.append(
             self.trim_matched_stop(
@@ -199,13 +202,15 @@ class DetokenizerManager:
             else:
                 new_text = find_printable_text(new_text)
 
-            output_strs.append(
-                self.trim_matched_stop(
-                    s.decoded_text + new_text,
-                    recv_obj.finished_reasons[i],
-                    recv_obj.no_stop_trim[i],
-                )
-            )
+            output_str = self.trim_matched_stop(
+                s.decoded_text + new_text,
+                recv_obj.finished_reasons[i],
+                recv_obj.no_stop_trim[i],
+            )
+            # Incrementally send text.
+            incremental_output = output_str[s.sent_offset :]
+            s.sent_offset = len(output_str)
+            output_strs.append(incremental_output)
 
         return BatchStrOut(
             rids=recv_obj.rids,
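
Together with the new DecodeStatus.sent_offset field and the switch to decode_ids.extend above, this hunk makes the detokenizer ship only the text appended since the previous batch instead of resending the whole decoded string every step. A minimal standalone sketch of the bookkeeping (simplified names, not the actual class):

    class DecodeState:
        """Per-request bookkeeping, mirroring DecodeStatus.sent_offset."""

        def __init__(self) -> None:
            self.sent_offset = 0  # characters already shipped downstream

        def take_increment(self, full_text: str) -> str:
            # Ship only the newly decoded suffix: O(new chars) per step
            # instead of O(total chars) for resending the full string.
            delta = full_text[self.sent_offset :]
            self.sent_offset = len(full_text)
            return delta

    s = DecodeState()
    assert s.take_increment("Hello") == "Hello"
    assert s.take_increment("Hello, world") == ", world"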
@@ -232,7 +237,15 @@ class DetokenizerManager:
         )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
-        raise NotImplementedError()
+        outputs = self.tokenizer.detokenize(recv_obj)
+        return BatchMultimodalOut(
+            rids=recv_obj.rids,
+            finished_reasons=recv_obj.finished_reasons,
+            outputs=outputs,
+            prompt_tokens=recv_obj.prompt_tokens,
+            completion_tokens=recv_obj.completion_tokens,
+            cached_tokens=recv_obj.cached_tokens,
+        )
 
 
 class LimitedCapacityDict(OrderedDict):
sglang/srt/managers/io_struct.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 """
-The definition of objects transfered between different
+The definition of objects transferred between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput:
     pass
 
 
+@dataclass
+class SlowDownReqInput:
+    forward_sleep_time: Optional[float]
+
+
+@dataclass
+class SlowDownReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
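
The new pair follows the same request/response convention as the other control-plane messages in io_struct.py. The field semantics are not shown in this hunk; presumably a float injects a per-forward pause and None clears it, so a hypothetical caller would look like:

    # Hypothetical usage; semantics inferred from the Optional[float] type only.
    slow_down = SlowDownReqInput(forward_sleep_time=0.05)  # pause 50 ms per forward
    restore = SlowDownReqInput(forward_sleep_time=None)    # presumably clears it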
@@ -826,6 +836,8 @@ class ProfileReqInput:
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
     activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 class ProfileReqType(Enum):
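
The two new fields mirror the identically named arguments of torch.profiler.profile, which is presumably where the scheduler forwards them. For reference, how those flags behave with the stock profiler (self-contained, CPU-only):

    import torch
    from torch.profiler import ProfilerActivity, profile

    x = torch.randn(1024, 1024)
    with profile(
        activities=[ProfilerActivity.CPU],
        record_shapes=True,  # capture input shapes of each op
        with_stack=True,     # capture Python call stacks (adds overhead)
    ) as prof:
        (x @ x).relu()

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))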
sglang/srt/managers/mm_utils.py
@@ -51,7 +51,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        This function will replace the data-tokens inbetween with pad_values accordingly
+        This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
sglang/srt/managers/multimodal_processors/base_processor.py
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
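
Fast HuggingFace image processors can return pixel_values on the GPU; copying them to host memory before the processor result leaves this process avoids shipping CUDA storage through pickling/IPC. The same guard restated as a standalone helper (hypothetical name, for illustration only):

    import torch

    def sanitize_for_ipc(result: dict) -> dict:
        # Hypothetical helper: CUDA tensors pickled across process boundaries
        # need CUDA IPC handles; copying to CPU first keeps the downstream
        # queue/ZMQ plumbing device-agnostic.
        pv = result.get("pixel_values")
        if isinstance(pv, torch.Tensor) and pv.is_cuda:
            result["pixel_values"] = pv.to("cpu")
        return result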
sglang/srt/managers/multimodal_processors/internvl.py
@@ -0,0 +1,232 @@
+# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
+
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.internvl import InternVLChatModel
+
+
+class InternVLImageProcessor(BaseMultimodalProcessor):
+    models = [InternVLChatModel]
+
+    def __init__(self, hf_config, server_args, _image_processor):
+        super().__init__(hf_config, server_args, _image_processor)
+        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+        patch_size = hf_config.vision_config.patch_size
+
+        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+        self.IMG_START_TOKEN = "<img>"
+        self.IMG_END_TOKEN = "</img>"
+        self.IMG_TOKEN = "<image>"
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
+        )
+
+        tokenizer = self._processor
+        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
+        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
+        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
+            self.IMG_CONTEXT_TOKEN
+        )
+
+    @staticmethod
+    def build_transform(input_size):
+        IMAGENET_MEAN = (0.485, 0.456, 0.406)
+        IMAGENET_STD = (0.229, 0.224, 0.225)
+
+        def resize_image(img, size):
+            return img.resize((size, size), Image.Resampling.BICUBIC)
+
+        def to_tensor(img):
+            # Convert PIL Image to numpy array
+            img_array = np.array(img).astype(np.float32) / 255.0
+            # Convert HWC to CHW format
+            img_array = img_array.transpose(2, 0, 1)
+            return torch.from_numpy(img_array)
+
+        def normalize(tensor, mean, std):
+            mean = torch.tensor(mean).view(-1, 1, 1)
+            std = torch.tensor(std).view(-1, 1, 1)
+            return (tensor - mean) / std
+
+        def transform(img):
+            img = img.convert("RGB") if img.mode != "RGB" else img
+            img = resize_image(img, input_size)
+            tensor = to_tensor(img)
+            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
+            return tensor
+
+        return transform
+
+    @staticmethod
+    def dynamic_preprocess(
+        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+    ):
+
+        def find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, width, height, image_size
+        ):
+            best_ratio_diff = float("inf")
+            best_ratio = (1, 1)
+            area = width * height
+            for ratio in target_ratios:
+                target_aspect_ratio = ratio[0] / ratio[1]
+                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+                if ratio_diff < best_ratio_diff:
+                    best_ratio_diff = ratio_diff
+                    best_ratio = ratio
+                elif ratio_diff == best_ratio_diff:
+                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                        best_ratio = ratio
+            return best_ratio
+
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j)
+            for n in range(min_num, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num and i * j >= min_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size
+        )
+
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size,
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+    @staticmethod
+    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+        if bound:
+            start, end = bound[0], bound[1]
+        else:
+            start, end = -100000, 100000
+        start_idx = max(first_idx, round(start * fps))
+        end_idx = min(round(end * fps), max_frame)
+        seg_size = float(end_idx - start_idx) / num_segments
+        frame_indices = np.array(
+            [
+                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+                for idx in range(num_segments)
+            ]
+        )
+        return frame_indices
+
+    @staticmethod
+    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        max_frame = len(vr) - 1
+        fps = float(vr.get_avg_fps())
+
+        pixel_values_list, num_patches_list = [], []
+        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        frame_indices = InternVLImageProcessor.get_index(
+            bound, fps, max_frame, first_idx=0, num_segments=num_segments
+        )
+        for frame_index in frame_indices:
+            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+            img = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(tile) for tile in img]
+            pixel_values = torch.stack(pixel_values)
+            num_patches_list.append(pixel_values.shape[0])
+            pixel_values_list.append(pixel_values)
+        pixel_values = torch.cat(pixel_values_list)
+        return pixel_values, num_patches_list
+
+    async def process_mm_data_async(
+        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
+            max_req_input_len=max_req_input_len,
+            discard_alpha_channel=True,
+        )
+
+        def process_image_internvl(image, input_size=448, max_num=12):
+            transform = InternVLImageProcessor.build_transform(input_size=input_size)
+            images = InternVLImageProcessor.dynamic_preprocess(
+                image, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(image) for image in images]
+            pixel_values = torch.stack(pixel_values)
+            return pixel_values
+
+        num_patches_list = []
+        pixel_values = []
+        # Process each input with allocated frames
+        for image_index, (image) in enumerate(base_output.images):
+            try:
+                # TODO: video input
+                raw_image = process_image_internvl(image)
+                pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+                pixel_values += pixel_value
+                num_patches = raw_image.shape[0]
+                num_patches_list += [num_patches]
+
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+        pixel_values = torch.cat(pixel_values, dim=0)
+        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
+
+        for idx, num_patches in enumerate(num_patches_list):
+            image_tokens = (
+                self.IMG_START_TOKEN
+                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                + self.IMG_END_TOKEN
+            )
+            input_text = input_text.replace("<image>", image_tokens, 1)
+
+        tokenizer = self._processor
+        return {
+            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
+            .flatten()
+            .tolist(),
+            "mm_items": items,
+            "im_start_id": self.img_start_token_id,
+            "im_end_id": self.img_end_token_id,
+            "im_token_id": self.img_context_token_id,
+        }
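
A worked example of the tiling arithmetic above, using typical InternVL2 configuration values (448-pixel tiles, patch_size 14, downsample_ratio 0.5; these are assumptions, not taken from this diff): a 1280x960 input has aspect ratio 4:3, so find_closest_aspect_ratio picks the (4, 3) grid and dynamic_preprocess yields 12 tiles plus one thumbnail, each mapping to (448 // 14) ** 2 * 0.25 = 256 <IMG_CONTEXT> tokens. A quick sketch exercising the static methods defined above:

    import torch
    from PIL import Image

    img = Image.new("RGB", (1280, 960))  # synthetic 4:3 image
    tiles = InternVLImageProcessor.dynamic_preprocess(
        img, image_size=448, use_thumbnail=True, max_num=12
    )
    assert len(tiles) == 13  # 4 x 3 grid plus one thumbnail

    transform = InternVLImageProcessor.build_transform(input_size=448)
    pixel_values = torch.stack([transform(t) for t in tiles])
    assert pixel_values.shape == (13, 3, 448, 448)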