sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        … (removed implementation not shown in this view)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        … (removed implementation not shown in this view)
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        … (removed implementation not shown in this view)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-            … (removed lines not shown in this view)
-            if batch_operation is None:
-                batch_operation = op
-            else:
-                batch_operation.merge(op)
-            if batch_operation is None:
-                … (removed line not shown in this view)
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
+                if batch_operation is None:
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-            … (removed implementation not shown in this view)
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
+
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
        """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
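The rewritten thread bodies above share one pattern: each worker selects its own CUDA stream, then polls a queue with a timeout so `stop_event` can interrupt it, acking node ids only after the transfer completes. A minimal, CPU-only sketch of that pattern (the `StreamWorker` class, queue names, and ack protocol here are illustrative stand-ins, not the sglang API):

```python
import logging
import queue
import threading

logger = logging.getLogger(__name__)

class StreamWorker:
    """Toy version of the queue-polling worker loop used above."""

    def __init__(self):
        self.work_queue = queue.Queue()   # operations to run
        self.ack_queue = queue.Queue()    # completion notifications
        self.stop_event = threading.Event()

    def worker(self):
        # The real code first pins a dedicated CUDA stream via
        # torch.cuda.set_stream(...); omitted here to stay CPU-only.
        while not self.stop_event.is_set():
            try:
                # The 1s timeout keeps the loop responsive to stop_event
                # even when the queue stays empty.
                op = self.work_queue.get(block=True, timeout=1)
                op()                      # perform the transfer
                self.ack_queue.put(op)    # ack only after completion
            except queue.Empty:
                continue
            except Exception as e:
                # Log and keep the worker alive rather than crashing it.
                logger.error(e)

w = StreamWorker()
t = threading.Thread(target=w.worker, daemon=True)
t.start()
w.work_queue.put(lambda: None)   # enqueue a no-op operation
w.ack_queue.get()                # wait for its ack
w.stop_event.set()
t.join()
```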
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
-        … (several removed lines not shown in this view)
-                    // self.write_buffer.max_buffer_size
-                )
-                _to_op(buffer)
-                buffer = None
-                if factor < 2:
-                    _to_op(operation)
-                else:
-                    split_ops = operation.split(factor)
-                    for op_ in split_ops:
-                        _to_op(op_)
-                continue
-            if buffer is None:
-                buffer = operation
-            else:
-                buffer.merge(operation)
-            if (
-                no_wait
-                or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                or self.write_queue.empty()
-                or self.write_buffer.empty()
-            ):
-                _to_op(buffer)
-                buffer = None
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
+
+                if factor >= 1:
+                    if buffer is not None:
+                        _to_op(buffer)
+                        buffer = None
+
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
+                    continue
+
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
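The restructured `write_aux_func` is a simple batching policy: an operation spanning at least one full buffer (`factor >= 1`) first flushes anything pending, then is forwarded whole or split into buffer-sized chunks; smaller operations accumulate into `buffer` until a flush condition fires. A toy illustration of just that policy, with a hypothetical `Op` class standing in for the real operation objects:

```python
MAX_BUFFER_SIZE = 4

class Op:
    """Hypothetical operation carrying a list of indices."""
    def __init__(self, indices):
        self.indices = list(indices)

    def merge(self, other):
        self.indices += other.indices

    def split(self, factor):
        # even chunks; the last chunk takes any remainder
        n = -(-len(self.indices) // factor)  # ceil division
        return [Op(self.indices[i:i + n]) for i in range(0, len(self.indices), n)]

def batch(ops):
    """Yield flush-ready ops, mirroring the factor >= 1 / factor < 2 branches."""
    buffer = None
    for op in ops:
        factor = len(op.indices) // MAX_BUFFER_SIZE
        if factor >= 1:
            if buffer is not None:           # flush pending small ops first
                yield buffer
                buffer = None
            if factor < 2:
                yield op                     # fits one buffer; send whole
            else:
                yield from op.split(factor)  # split oversized op
            continue
        if buffer is None:                   # accumulate small ops
            buffer = Op(op.indices)
        else:
            buffer.merge(op)
        if len(buffer.indices) >= MAX_BUFFER_SIZE:
            yield buffer
            buffer = None
    if buffer is not None:
        yield buffer

print([o.indices for o in batch([Op([1]), Op([2]), Op(range(9))])])
# -> [[1, 2], [0, 1, 2, 3, 4], [5, 6, 7, 8]]
```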
@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-        … (removed consumer loop not shown in this view)
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
        aux_thread.join()
 
     def evict_device(
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -17,13 +17,13 @@ import logging
 import multiprocessing as mp
 import signal
 import threading
+import time
 from enum import Enum, auto
 
 import psutil
 import setproctitle
 import zmq
 
-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
@@ -158,7 +158,7 @@ class DataParallelController:
         # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
         # function in scheduler.py will kill the scheduler.
         while True:
-            … (removed line not shown in this view)
+            time.sleep(30 * 24 * 3600)
 
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
@@ -210,7 +210,7 @@ class DataParallelController:
         )
         # compute zmq ports for this dp rank
         rank_port_args = PortArgs.init_new(server_args, dp_rank)
-        # Data parallelism
+        # Data parallelism reuses the tensor parallelism group,
         # so all dp ranks should use the same nccl port.
         rank_port_args.nccl_port = port_args.nccl_port
 
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalDecodeReq,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
 )
@@ -60,6 +61,8 @@ class DecodeStatus:
     decode_ids: List[int]
     surr_offset: int
     read_offset: int
+    # Offset that's sent to tokenizer for incremental update.
+    sent_offset: int = 0
 
 
 class DetokenizerManager:
@@ -151,7 +154,7 @@ class DetokenizerManager:
             self.decode_status[rid] = s
         else:
             s = self.decode_status[rid]
-            s.decode_ids
+            s.decode_ids.extend(recv_obj.decode_ids[i])
 
             read_ids.append(
                 self.trim_matched_stop(
@@ -199,13 +202,15 @@
             else:
                 new_text = find_printable_text(new_text)
 
-            … (removed lines not shown in this view)
-                recv_obj.no_stop_trim[i],
-            )
+            output_str = self.trim_matched_stop(
+                s.decoded_text + new_text,
+                recv_obj.finished_reasons[i],
+                recv_obj.no_stop_trim[i],
             )
+            # Incrementally send text.
+            incremental_output = output_str[s.sent_offset :]
+            s.sent_offset = len(output_str)
+            output_strs.append(incremental_output)
 
         return BatchStrOut(
             rids=recv_obj.rids,
@@ -232,7 +237,15 @@
         )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
-        … (removed line not shown in this view)
+        outputs = self.tokenizer.detokenize(recv_obj)
+        return BatchMultimodalOut(
+            rids=recv_obj.rids,
+            finished_reasons=recv_obj.finished_reasons,
+            outputs=outputs,
+            prompt_tokens=recv_obj.prompt_tokens,
+            completion_tokens=recv_obj.completion_tokens,
+            cached_tokens=recv_obj.cached_tokens,
+        )
 
 
 class LimitedCapacityDict(OrderedDict):
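The new `sent_offset` field changes the detokenizer from resending the full decoded string on every step to sending only the suffix the client has not yet seen. The idea in isolation (a sketch; `DecodeStatus` is reduced here to the fields that matter):

```python
from dataclasses import dataclass

@dataclass
class DecodeStatus:
    decoded_text: str = ""
    sent_offset: int = 0  # how much of the output the client already has

def emit_increment(s: DecodeStatus, new_text: str) -> str:
    """Return only the not-yet-sent suffix of the full output string."""
    output_str = s.decoded_text + new_text
    incremental_output = output_str[s.sent_offset:]
    s.sent_offset = len(output_str)
    return incremental_output

s = DecodeStatus()
assert emit_increment(s, "Hello") == "Hello"
s.decoded_text = "Hello"
assert emit_increment(s, ", world") == ", world"  # only the new part is sent
```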
sglang/srt/managers/io_struct.py
CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 """
-The definition of objects
+The definition of objects transferred between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput:
     pass
 
 
+@dataclass
+class SlowDownReqInput:
+    forward_sleep_time: Optional[float]
+
+
+@dataclass
+class SlowDownReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -826,6 +836,8 @@ class ProfileReqInput:
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
     activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 class ProfileReqType(Enum):
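The two new `ProfileReqInput` fields look like pass-throughs to `torch.profiler.profile`, which accepts `with_stack` and `record_shapes` keyword arguments; `SlowDownReqInput` carries an optional sleep to inject into forward passes. A hedged sketch of constructing such request objects (field names are from the diff, the semantics in the comments are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ProfileReqInput:
    num_steps: Optional[int] = None
    with_stack: Optional[bool] = None      # presumably torch.profiler's with_stack
    record_shapes: Optional[bool] = None   # presumably torch.profiler's record_shapes

@dataclass
class SlowDownReqInput:
    # Presumably seconds to sleep around each forward pass; None disables it.
    forward_sleep_time: Optional[float]

req = ProfileReqInput(num_steps=10, with_stack=True, record_shapes=True)
slow = SlowDownReqInput(forward_sleep_time=0.5)
print(req, slow)
```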
sglang/srt/managers/mm_utils.py
CHANGED
@@ -51,7 +51,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        This function will replace the data-tokens
+        This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
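For context on the clarified docstring: the pattern scans `input_ids` for (start, end) token pairs and replaces the data tokens between each pair with that item's `pad_value`. A minimal sketch of the idea (simplified to a single token pair, not the sglang implementation):

```python
from typing import List, Tuple

def pad_between_token_pairs(
    input_ids: List[int],
    pair: Tuple[int, int],
    pad_values: List[int],
) -> List[int]:
    """Replace tokens strictly between each (start, end) pair with a pad value."""
    start_id, end_id = pair
    out, i, k = [], 0, 0
    while i < len(input_ids):
        out.append(input_ids[i])
        if input_ids[i] == start_id:
            j = input_ids.index(end_id, i + 1)  # matching end token
            out.extend([pad_values[k % len(pad_values)]] * (j - i - 1))
            out.append(end_id)
            k += 1
            i = j
        i += 1
    return out

# tokens 7/8 delimit an image; 1, 2, 3 are its data tokens
print(pad_between_token_pairs([5, 7, 1, 2, 3, 8, 6], (7, 8), [99]))
# -> [5, 7, 99, 99, 99, 8, 6]
```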
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
sglang/srt/managers/multimodal_processors/internvl.py
ADDED
@@ -0,0 +1,232 @@
+# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
+
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from numpy.distutils.cpuinfo import cpu
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.internvl import InternVLChatModel
+
+
+class InternVLImageProcessor(BaseMultimodalProcessor):
+    models = [InternVLChatModel]
+
+    def __init__(self, hf_config, server_args, _image_processor):
+        super().__init__(hf_config, server_args, _image_processor)
+        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+        patch_size = hf_config.vision_config.patch_size
+
+        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+        self.IMG_START_TOKEN = "<img>"
+        self.IMG_END_TOKEN = "</img>"
+        self.IMG_TOKEN = "<image>"
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
+        )
+
+        tokenizer = self._processor
+        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
+        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
+        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
+            self.IMG_CONTEXT_TOKEN
+        )
+
+    @staticmethod
+    def build_transform(input_size):
+        IMAGENET_MEAN = (0.485, 0.456, 0.406)
+        IMAGENET_STD = (0.229, 0.224, 0.225)
+
+        def resize_image(img, size):
+            return img.resize((size, size), Image.Resampling.BICUBIC)
+
+        def to_tensor(img):
+            # Convert PIL Image to numpy array
+            img_array = np.array(img).astype(np.float32) / 255.0
+            # Convert HWC to CHW format
+            img_array = img_array.transpose(2, 0, 1)
+            return torch.from_numpy(img_array)
+
+        def normalize(tensor, mean, std):
+            mean = torch.tensor(mean).view(-1, 1, 1)
+            std = torch.tensor(std).view(-1, 1, 1)
+            return (tensor - mean) / std
+
+        def transform(img):
+            img = img.convert("RGB") if img.mode != "RGB" else img
+            img = resize_image(img, input_size)
+            tensor = to_tensor(img)
+            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
+            return tensor
+
+        return transform
+
+    @staticmethod
+    def dynamic_preprocess(
+        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+    ):
+
+        def find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, width, height, image_size
+        ):
+            best_ratio_diff = float("inf")
+            best_ratio = (1, 1)
+            area = width * height
+            for ratio in target_ratios:
+                target_aspect_ratio = ratio[0] / ratio[1]
+                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+                if ratio_diff < best_ratio_diff:
+                    best_ratio_diff = ratio_diff
+                    best_ratio = ratio
+                elif ratio_diff == best_ratio_diff:
+                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                        best_ratio = ratio
+            return best_ratio
+
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j)
+            for n in range(min_num, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num and i * j >= min_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size
+        )
+
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size,
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+    @staticmethod
+    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+        if bound:
+            start, end = bound[0], bound[1]
+        else:
+            start, end = -100000, 100000
+        start_idx = max(first_idx, round(start * fps))
+        end_idx = min(round(end * fps), max_frame)
+        seg_size = float(end_idx - start_idx) / num_segments
+        frame_indices = np.array(
+            [
+                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+                for idx in range(num_segments)
+            ]
+        )
+        return frame_indices
+
+    @staticmethod
+    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        max_frame = len(vr) - 1
+        fps = float(vr.get_avg_fps())
+
+        pixel_values_list, num_patches_list = [], []
+        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        frame_indices = InternVLImageProcessor.get_index(
+            bound, fps, max_frame, first_idx=0, num_segments=num_segments
+        )
+        for frame_index in frame_indices:
+            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+            img = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(tile) for tile in img]
+            pixel_values = torch.stack(pixel_values)
+            num_patches_list.append(pixel_values.shape[0])
+            pixel_values_list.append(pixel_values)
+        pixel_values = torch.cat(pixel_values_list)
+        return pixel_values, num_patches_list
+
+    async def process_mm_data_async(
+        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
+            max_req_input_len=max_req_input_len,
+            discard_alpha_channel=True,
+        )
+
+        def process_image_internvl(image, input_size=448, max_num=12):
+            transform = InternVLImageProcessor.build_transform(input_size=input_size)
+            images = InternVLImageProcessor.dynamic_preprocess(
+                image, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(image) for image in images]
+            pixel_values = torch.stack(pixel_values)
+            return pixel_values
+
+        num_patches_list = []
+        pixel_values = []
+        # Process each input with allocated frames
+        for image_index, (image) in enumerate(base_output.images):
+            try:
+                # TODO: video input
+                raw_image = process_image_internvl(image)
+                pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+                pixel_values += pixel_value
+                num_patches = raw_image.shape[0]
+                num_patches_list += [num_patches]
+
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+        pixel_values = torch.cat(pixel_values, dim=0)
+        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
+
+        for idx, num_patches in enumerate(num_patches_list):
+            image_tokens = (
+                self.IMG_START_TOKEN
+                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                + self.IMG_END_TOKEN
+            )
+            input_text = input_text.replace("<image>", image_tokens, 1)
+
+        tokenizer = self._processor
+        return {
+            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
+            .flatten()
+            .tolist(),
+            "mm_items": items,
+            "im_start_id": self.img_start_token_id,
+            "im_end_id": self.img_end_token_id,
+            "im_token_id": self.img_context_token_id,
+        }
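To make `dynamic_preprocess` concrete: it enumerates all (cols, rows) grids with between `min_num` and `max_num` tiles, picks the grid whose aspect ratio is closest to the image's, resizes the image to `cols*448 x rows*448`, and crops one tile per cell (plus an optional thumbnail). One thing a reviewer may notice in the file above: `cpu` is imported from both `decord` and `numpy.distutils.cpuinfo`, and the second import shadows the first, which `load_video` relies on. The grid-selection math in isolation (pure Python, mirroring the code above; the tie-break on image area is omitted):

```python
def pick_grid(width, height, min_num=1, max_num=12):
    """Choose the (cols, rows) tile grid whose aspect ratio best matches the image."""
    aspect = width / height
    grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    # closest aspect ratio wins
    return min(grids, key=lambda g: abs(aspect - g[0] / g[1]))

cols, rows = pick_grid(1280, 800)   # a 16:10 image
print(cols, rows, cols * rows)      # (3, 2) -> 6 tiles of 448x448
```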