sglang 0.3.6.post1__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +4 -8
  3. sglang/bench_one_batch_server.py +6 -5
  4. sglang/check_env.py +7 -1
  5. sglang/lang/tracer.py +1 -1
  6. sglang/launch_server.py +2 -4
  7. sglang/srt/configs/model_config.py +2 -6
  8. sglang/srt/layers/attention/flashinfer_backend.py +3 -3
  9. sglang/srt/layers/sampler.py +1 -1
  10. sglang/srt/managers/data_parallel_controller.py +7 -11
  11. sglang/srt/managers/detokenizer_manager.py +7 -6
  12. sglang/srt/managers/image_processor.py +7 -10
  13. sglang/srt/managers/io_struct.py +0 -10
  14. sglang/srt/managers/schedule_batch.py +51 -13
  15. sglang/srt/managers/scheduler.py +41 -29
  16. sglang/srt/managers/session_controller.py +15 -7
  17. sglang/srt/managers/tokenizer_manager.py +4 -33
  18. sglang/srt/managers/tp_worker_overlap_thread.py +11 -2
  19. sglang/srt/models/grok.py +11 -48
  20. sglang/srt/models/llava.py +16 -9
  21. sglang/srt/models/olmo2.py +392 -0
  22. sglang/srt/models/qwen2_vl.py +10 -3
  23. sglang/srt/openai_api/adapter.py +1 -1
  24. sglang/srt/server.py +48 -45
  25. sglang/srt/server_args.py +1 -1
  26. sglang/srt/utils.py +22 -24
  27. sglang/test/test_utils.py +21 -8
  28. sglang/utils.py +2 -2
  29. sglang/version.py +1 -1
  30. {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/METADATA +4 -2
  31. {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/RECORD +34 -36
  32. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  33. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  34. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  35. {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/LICENSE +0 -0
  36. {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/WHEEL +0 -0
  37. {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py CHANGED
@@ -15,6 +15,7 @@
 
 import logging
 import os
+import signal
 import threading
 import time
 import warnings
@@ -23,6 +24,7 @@ from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional
 
+import psutil
 import torch
 import zmq
 
@@ -36,8 +38,6 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -71,9 +71,9 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    get_bool_env_var,
     get_zmq_socket,
-    gpu_proc_affinity,
-    kill_parent_process,
+    set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -82,7 +82,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
+test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
 class Scheduler:
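Note: `get_bool_env_var` is a new helper in `sglang/srt/utils.py` (changed above, +22/-24); its definition is not part of this hunk. Judging from the inline check it replaces, it presumably behaves roughly like the sketch below (the exact set of accepted truthy values is an assumption):

```python
# Hedged sketch of the helper, inferred from the call site it replaces;
# the real implementation lives in sglang/srt/utils.py.
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat "true"/"1" (case-insensitive) as enabled; anything else is disabled.
    return os.getenv(name, default).strip().lower() in ("true", "1")
```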
@@ -169,6 +169,10 @@ class Scheduler:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
 
+        if self.model_config.is_multimodal:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for multimodal models.")
+
         if self.enable_overlap:
             self.disable_jump_forward = True
 
@@ -311,6 +315,7 @@ class Scheduler:
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
+        self.parent_process = psutil.Process().parent()
 
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
@@ -354,7 +359,7 @@
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)
 
-        kill_parent_process()
+        self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
     def event_loop_normal(self):
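The shutdown path changes here: instead of calling `kill_parent_process()`, the scheduler records its parent once via `psutil.Process().parent()` and, when the watchdog fires (or an exception escapes, see the last hunk of this file), sends it `SIGQUIT` so the launching process can tear the whole server down. A minimal sketch of the pattern, with illustrative handler names rather than sglang's actual code:

```python
# Child/parent signalling sketch; function names here are illustrative.
import os
import signal
import psutil

# In the child process (e.g. the scheduler):
parent_process = psutil.Process().parent()

def report_fatal_error() -> None:
    # Ask the parent to shut everything down instead of killing it outright.
    parent_process.send_signal(signal.SIGQUIT)

# In the parent process (e.g. the launcher):
def sigquit_handler(signum, frame):
    # A child reported a fatal error; a real server would also kill the
    # remaining worker processes before exiting.
    os._exit(1)

signal.signal(signal.SIGQUIT, sigquit_handler)
```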
@@ -514,10 +519,6 @@
                 self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
             elif isinstance(recv_req, CloseSessionReqInput):
                 self.close_session(recv_req)
-            elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.send_to_tokenizer.send_pyobj(
-                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
-                )
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -525,8 +526,9 @@
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        # Create a new request
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
-            # Create a new request
+
             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
                 seq_length = len(recv_req.input_embeds)
@@ -557,24 +559,30 @@
             self.waiting_queue.append(req)
             return
 
-        # Image inputs
+        # Handle image inputs
         if recv_req.image_inputs is not None:
-            req.image_inputs = ImageInputs.from_dict(
-                recv_req.image_inputs, self.model_config.vocab_size
-            )
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
-                req.origin_input_ids_unpadded, req.image_inputs
+                req.origin_input_ids, image_inputs
             )
+            req.extend_image_inputs(image_inputs)
 
-            if len(req.origin_input_ids) > self.max_req_input_len:
-                req.finished_reason = FINISH_ABORT(
-                    "Image request length is longer than the KV cache pool size or "
-                    "the max context length aborting because you cannot truncate the image embeds"
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                logger.error(
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                 )
+                req.origin_input_ids = [0]
+                req.image_inputs = None
                 req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    "Multimodal prompt is too long. Check server logs for details."
+                )
                 self.waiting_queue.append(req)
                 return
 
+        # Copy more attributes
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
@@ -1342,13 +1350,15 @@
 
         if to_del is not None:
             del self.waiting_queue[to_del]
+            logger.debug(f"Abort queued request. {req.rid=}")
+            return
 
         # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
-                    req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    logger.debug(f"Abort running request. {req.rid=}")
+                    req.to_abort = True
                    break
 
    def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1404,11 +1414,12 @@ def run_scheduler_process(
     pipe_writer,
 ):
     # set cpu affinity to this gpu process
-    gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "DP_RANK" in os.environ:
-        dp_rank = int(os.environ["DP_RANK"])
+    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
 
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
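The router-facing environment variable is renamed from `DP_RANK` to `SGLANG_DP_RANK`, so any external launcher that still sets the old name must be updated. An illustrative (not prescriptive) way a router might pass the rank after this change; the launch command and arguments are placeholders:

```python
# Illustrative only: pass the data-parallel rank through the renamed variable.
import os
import subprocess

env = os.environ.copy()
env["SGLANG_DP_RANK"] = "0"  # was "DP_RANK" in 0.3.6.post1
subprocess.Popen(
    ["python", "-m", "sglang.launch_server", "--model-path", "<model>"],
    env=env,
)
```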
@@ -1416,6 +1427,7 @@ def run_scheduler_process(
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
 
     suppress_other_loggers()
+    parent_process = psutil.Process().parent()
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
@@ -1427,6 +1439,6 @@ def run_scheduler_process(
         else:
             scheduler.event_loop_normal()
     except Exception:
-        msg = get_exception_traceback()
-        logger.error(msg)
-        kill_parent_process()
+        traceback = get_exception_traceback()
+        logger.error(f"Scheduler hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/session_controller.py CHANGED
@@ -10,10 +10,7 @@
 # limitations under the License.
 # ==============================================================================
 
-import copy
 import uuid
-from dataclasses import dataclass
-from typing import Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
@@ -41,16 +38,27 @@ class Session:
                 ]
                 + req.input_ids
             )
+            input_ids_unpadded = (
+                self.reqs[-1].origin_input_ids_unpadded
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
         else:
             input_ids = req.input_ids
+            input_ids_unpadded = req.input_ids
         new_req = Req(
-            req.rid,
-            None,
-            input_ids,
-            req.sampling_params,
+            rid=req.rid,
+            origin_input_text=None,
+            origin_input_ids=input_ids,
+            origin_input_ids_unpadded=input_ids_unpadded,
+            sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
+        if len(self.reqs) > 0:
+            new_req.image_inputs = self.reqs[-1].image_inputs
         new_req.tokenizer = tokenizer
         if req.session_rid is not None and len(self.reqs) == 0:
             new_req.finished_reason = FINISH_ABORT(
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -58,7 +56,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, kill_child_process
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -218,7 +216,8 @@ class TokenizerManager:
             input_ids = obj.input_ids
 
         if self.is_generation:
-            image_inputs = await self.image_processor.process_images_async(
+            # TODO: also support getting embeddings for multimodal models
+            image_inputs: Dict = await self.image_processor.process_images_async(
                 obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
@@ -406,25 +405,6 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
 
-    async def get_memory_pool_size(self):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        req = GetMemPoolSizeReq()
-
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-
-        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
-        if self.server_args.dp_size == 1:
-            res = await self.mem_pool_size
-            return res.size
-        else:  # self.server_args.dp_size > 1
-            self.mem_pool_size_tmp = []
-            res = await self.mem_pool_size
-            ret = [r.size for r in res]
-            return ret
-
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -532,7 +512,7 @@ class TokenizerManager:
             else:
                 break
 
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(0)
 
     async def handle_loop(self):
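`kill_child_process(include_self=True)` becomes the more explicit `kill_process_tree(os.getpid(), include_parent=True)`. The new helper is defined in `sglang/srt/utils.py` and its body is not shown in this diff; judging from the name and call site, it presumably walks the process tree with psutil along these lines:

```python
# Hedged sketch of kill_process_tree; the actual helper is in sglang/srt/utils.py.
import psutil

def kill_process_tree(parent_pid: int, include_parent: bool = True) -> None:
    try:
        parent = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return
    # Kill all descendants first, then optionally the root of the tree.
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    if include_parent:
        try:
            parent.kill()
        except psutil.NoSuchProcess:
            pass
```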
@@ -552,15 +532,6 @@ class TokenizerManager:
                 if len(self.model_update_tmp) == self.server_args.dp_size:
                     self.model_update_result.set_result(self.model_update_tmp)
                 continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.mem_pool_size.set_result(recv_obj)
-                else:  # self.sever_args.dp_size > 1
-                    self.mem_pool_size_tmp.append(recv_obj)
-                    # set future if the all results are received
-                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
-                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
-                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
sglang/srt/managers/tp_worker_overlap_thread.py CHANGED
@@ -15,16 +15,19 @@
 
 import dataclasses
 import logging
+import signal
 import threading
 from queue import Queue
 from typing import Optional
 
+import psutil
 import torch
 
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +73,7 @@ class TpModelWorkerClient:
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
+        self.parent_process = psutil.Process().parent()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -87,8 +91,13 @@ class TpModelWorkerClient:
         )
 
     def forward_thread_func(self):
-        with torch.cuda.stream(self.forward_stream):
-            self.forward_thread_func_()
+        try:
+            with torch.cuda.stream(self.forward_stream):
+                self.forward_thread_func_()
+        except Exception:
+            traceback = get_exception_traceback()
+            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
+            self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
     def forward_thread_func_(self):
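The overlap worker's forward loop runs in a background `threading.Thread`. An uncaught exception there only kills that thread, so the scheduler would block forever waiting for results; wrapping the loop in try/except and signalling the parent turns that silent hang into an explicit shutdown. A small standalone illustration of the underlying Python behavior (not sglang code):

```python
# An exception in a background thread does not crash the process; the main
# thread keeps running and would wait forever on results from the dead thread.
import threading
import time

def worker():
    raise RuntimeError("forward pass failed")

t = threading.Thread(target=worker, daemon=True)
t.start()
time.sleep(0.1)
print("worker alive:", t.is_alive())   # False: the thread died silently
print("main thread continues running")
```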
sglang/srt/models/grok.py CHANGED
@@ -16,22 +16,17 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 
-import warnings
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.fused_moe_grok import FusedMoE
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -41,10 +36,12 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
@@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        # Monkey patch _prepare_weights to load pre-sharded weights
-        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-
-        self.use_presharded_weights = True
-
-        warnings.filterwarnings("ignore", category=FutureWarning)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
 
-                if self.use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": self.use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
                     param,
                     loaded_weight,
-                    weight_name,
+                    name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    **extra_kwargs,
                 )
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                 if name is None:
                     continue
 
@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-
-old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-
-
-def _prepare_presharded_weights(
-    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
-    import glob
-    import os
-
-    if get_tensor_model_parallel_world_size() == 1:
-        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
-
-    tp_rank = get_tensor_model_parallel_rank()
-    allow_patterns = [f"*-{tp_rank:03d}.bin"]
-
-    hf_folder = model_name_or_path
-
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-    use_safetensors = False
-
-    return hf_folder, hf_weights_files, use_safetensors
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
sglang/srt/models/llava.py CHANGED
@@ -49,9 +49,15 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
 
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
+        if image_inputs.modalities is not None and (
+            "multi-images" in image_inputs.modalities
+            or "video" in image_inputs.modalities
+        ):
+            image_aspect_ratio = "pad"
+        else:
+            image_aspect_ratio = "anyres"
         offset_list = []
-        for image_s in image_sizes:
+        for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
@@ -86,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
                 new_image_feature_len += new_h * (new_w + 1)
 
-            pad_ids = pad_values * (
-                (new_image_feature_len + len(pad_values)) // len(pad_values)
-            )
-            # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
                 offset = input_ids.index(self.config.image_token_index)
             except ValueError:
@@ -97,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                + pad_ids[:new_image_feature_len]
+                + [pad_values[image_idx]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
@@ -132,7 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs
 
         if forward_batch.forward_mode.is_extend():
-            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -140,11 +141,16 @@ class LlavaBaseForCausalLM(nn.Module):
             for im in image_inputs:
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
-                if im and im.image_offsets is not None:
+                if im and im.image_offsets:
                     max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)
 
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
 
@@ -152,6 +158,7 @@ class LlavaBaseForCausalLM(nn.Module):
             need_vision = start_positions <= np.array(max_image_offset)
 
             if need_vision.any():
+                bs = forward_batch.batch_size
                 pixel_values = [
                     image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                 ]
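The `input_ids.clamp_` call added above matters because, as the accompanying comment explains, the image placeholder positions in `input_ids` hold image hash values used for radix-cache prefix matching, and those hashes can fall outside `[0, vocab_size)`. Passing them unclamped to `embed_tokens` would raise an index error; the clamped rows are throwaway embeddings that the vision embeddings overwrite anyway. A small illustration with made-up sizes:

```python
# Illustration with made-up sizes: out-of-range placeholder ids break the
# embedding lookup unless they are clamped first.
import torch

vocab_size, hidden = 32000, 16
embed_tokens = torch.nn.Embedding(vocab_size, hidden)

input_ids = torch.tensor([1, 2, 3, 123456789, 4])  # 123456789: fake image-hash id
# embed_tokens(input_ids) would raise an IndexError here.
input_ids.clamp_(min=0, max=vocab_size - 1)
out = embed_tokens(input_ids)  # safe; the clamped row is a placeholder embedding
```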