sglang 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +34 -16
  5. sglang/global_config.py +1 -0
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +24 -27
  15. sglang/srt/layers/token_attention.py +4 -1
  16. sglang/srt/managers/controller/infer_batch.py +2 -2
  17. sglang/srt/managers/controller/manager_single.py +1 -1
  18. sglang/srt/managers/controller/model_runner.py +27 -15
  19. sglang/srt/managers/controller/tp_worker.py +31 -14
  20. sglang/srt/managers/detokenizer_manager.py +4 -2
  21. sglang/srt/managers/io_struct.py +1 -1
  22. sglang/srt/managers/tokenizer_manager.py +14 -13
  23. sglang/srt/model_config.py +6 -0
  24. sglang/srt/models/gemma2.py +436 -0
  25. sglang/srt/models/llama2.py +3 -3
  26. sglang/srt/models/llama_classification.py +10 -7
  27. sglang/srt/models/minicpm.py +373 -0
  28. sglang/srt/models/qwen2_moe.py +454 -0
  29. sglang/srt/openai_api_adapter.py +2 -2
  30. sglang/srt/openai_protocol.py +1 -1
  31. sglang/srt/server.py +17 -8
  32. sglang/srt/server_args.py +14 -16
  33. sglang/srt/utils.py +68 -35
  34. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
  35. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  37. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
  38. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
@@ -2,10 +2,10 @@
 
 import numpy as np
 import torch
+from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@@ -13,18 +13,22 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
-        layer_id: int, logit_cap: int = -1
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id
 
-        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
-
         from sglang.srt.managers.controller.model_runner import global_server_args_dict
 
         if not global_server_args_dict.get("disable_flashinfer", False):
@@ -32,29 +36,17 @@ class RadixAttention(nn.Module):
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
             # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap > 0 else 0
+            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap
+            self.logit_cap = logit_cap if logit_cap is not None else 0
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-            self.logit_cap,
-        )
-        self.store_kv_cache(k, v, input_metadata)
-
-        return o
+        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
+        # See the extend_forward_xxx functions.
+        raise NotImplementedError()
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -75,7 +67,8 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
@@ -96,18 +89,19 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
-            self.logit_cap,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )
 
         return o
 
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
             v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
 
@@ -118,12 +112,14 @@ class RadixAttention(nn.Module):
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
             causal=False,
+            sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
 
-        from flashinfer.cascade import merge_state
         o, _ = merge_state(o1, s1, o2, s2)
 
+        self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
@@ -135,6 +131,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )
 
sglang/srt/layers/token_attention.py
@@ -176,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
     logit_cap,
 ):
     BLOCK = 32
@@ -183,7 +184,6 @@
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -317,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    sm_scale=None,
     logit_cap=-1,
     att_m=None,
 ):
@@ -324,6 +325,7 @@
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
 
     _token_att_m_fwd(
         q,
@@ -334,6 +336,7 @@
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        sm_scale,
         logit_cap,
     )
     _token_softmax_reducev_fwd(
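Taken together, the radix_attention.py and token_attention.py hunks above stop hard-coding the softmax scale: the old assert that scaling equals 1/sqrt(head_dim) is removed, the layer stores self.scaling, and every kernel call receives it as sm_scale, with the Triton path falling back to 1/sqrt(head_dim) only when no scale is supplied. A minimal, self-contained sketch of that default-or-override rule (plain Python; the helper name and the numbers are illustrative, not sglang code):

from typing import Optional

def resolve_sm_scale(head_dim: int, sm_scale: Optional[float] = None) -> float:
    # Mirrors the new token_attention_fwd fallback: use the conventional
    # 1/sqrt(head_dim) softmax scale only when the caller passes nothing.
    return 1.0 / (head_dim ** 0.5) if sm_scale is None else sm_scale

# The scale the removed assert used to enforce unconditionally.
assert abs(resolve_sm_scale(128) - 1.0 / 128 ** 0.5) < 1e-12

# A model with a non-standard attention scale can now simply pass its own value.
assert resolve_sm_scale(128, sm_scale=0.0625) == 0.0625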
sglang/srt/managers/controller/infer_batch.py
@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union
 
 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class BaseFinishReason:
 
 
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int | List[int]):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
 
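The typing change above swaps the PEP 604 annotation `int | List[int]` for `typing.Union`, presumably so the module still imports on Python versions older than 3.10, where the `|` operator is not defined for types at runtime. A small self-contained illustration (the helper function is hypothetical, not sglang code):

from typing import List, Union

# "matched: int | List[int]" raises TypeError when the class body is executed
# on Python 3.8/3.9; typing.Union spells the same annotation portably.
Matched = Union[int, List[int]]

def describe(matched: Matched) -> str:
    if isinstance(matched, int):
        return f"finished on token {matched}"
    return f"finished on one of {matched}"

print(describe(2))        # finished on token 2
print(describe([2, 13]))  # finished on one of [2, 13]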
sglang/srt/managers/controller/manager_single.py
@@ -99,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
\ No newline at end of file
+        kill_parent_process()
sglang/srt/managers/controller/model_runner.py
@@ -127,7 +127,7 @@ class InputMetadata:
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
-                1
+                1,
             )
         else:
             self.flashinfer_decode_wrapper.end_forward()
@@ -140,7 +140,7 @@ class InputMetadata:
                 head_dim,
                 1,
                 pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
             )
 
     def init_extend_args(self):
@@ -228,7 +228,7 @@ class InputMetadata:
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
                 model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim
+                model_runner.model_config.head_dim,
             )
 
         return ret
@@ -259,7 +259,10 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-        monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
         if server_args.nccl_init_addr:
             nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
         else:
@@ -269,7 +272,7 @@ class ModelRunner:
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -323,7 +326,7 @@ class ModelRunner:
             device_config=device_config,
             load_config=load_config,
             lora_config=None,
-            vision_language_config=None,
+            multimodal_config=None,
             parallel_config=None,
             scheduler_config=None,
             cache_config=None,
@@ -341,7 +344,13 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -384,33 +393,36 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
-                BatchPrefillWithRaggedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
             )
             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
             if not _grouped_size_compiled_for_decode_kernels(
                 self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size)):
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
                 use_tensor_cores = True
             else:
                 use_tensor_cores = False
 
             workspace_buffers = torch.empty(
-                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
            )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                workspace_buffers[0], "NHD"
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffers[1], "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
             self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
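The reformatted cell_size expression above is the per-token KV-cache footprint the profiler uses when sizing the memory pool. A quick worked example with hypothetical model dimensions (the function name and numbers are illustrative only):

import torch

def kv_cache_cell_size(head_num: int, head_dim: int, num_hidden_layers: int,
                       dtype: torch.dtype) -> int:
    # Bytes needed to cache one token: K and V vectors for every KV head,
    # replicated across all layers, at the element size of the KV dtype.
    return (
        head_num
        * head_dim
        * num_hidden_layers
        * 2  # one K entry plus one V entry
        * torch._utils._element_size(dtype)
    )

# e.g. 8 KV heads on this GPU, head_dim 128, 32 layers, fp16:
# 8 * 128 * 32 * 2 * 2 = 131072 bytes, i.e. 128 KiB per cached token.
print(kv_cache_cell_size(8, 128, 32, torch.float16))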
sglang/srt/managers/controller/tp_worker.py
@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     start_rpyc_service_process,
-    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -368,9 +368,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and (req.extend_input_len + new_batch_input_tokens
-                <= self.max_prefill_tokens
-                or len(can_run_list) == 0)
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
@@ -452,7 +454,9 @@ class ModelTpServer:
                     next_token_ids,
                 ].tolist()
                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )
 
             next_token_ids = next_token_ids.tolist()
         else:
@@ -582,7 +586,9 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
                 if req.top_logprobs_num > 0:
                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
@@ -759,16 +765,27 @@ class ModelTpClient:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
                 if server_args.nnodes == 1:
-                    self.procs = list(executor.map(
-                        lambda args: start_rpyc_service_process(*args),
-                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                    ))
+                    self.procs = list(
+                        executor.map(
+                            lambda args: start_rpyc_service_process(*args),
+                            [
+                                (ModelTpService, p)
+                                for p in model_port_args.model_tp_ports
+                            ],
+                        )
+                    )
                     addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
                 else:
-                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
-
-                self.model_services = list(executor.map(
-                    lambda args: connect_rpyc_service(*args), addrs))
+                    addrs = [
+                        (ip, port)
+                        for ip, port in zip(
+                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                        )
+                    ]
+
+                self.model_services = list(
+                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
+                )
 
                 # Init model
                 def init_model(i):
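The reformatted scheduling condition in ModelTpServer above decides whether another request may join the next prefill batch. Restated as a stand-alone predicate (an illustrative paraphrase, not code that exists in sglang):

def can_admit_prefill(
    extend_input_len: int,
    max_new_tokens: int,
    new_batch_total_tokens: int,
    new_batch_input_tokens: int,
    available_size: int,
    max_prefill_tokens: int,
    batch_is_empty: bool,
) -> bool:
    # The request's prompt plus its generation budget must fit in the free
    # KV-cache space, and its prompt must respect the per-batch prefill token
    # limit unless the batch is still empty (so one long prompt can still run).
    fits_kv_cache = (
        extend_input_len + max_new_tokens + new_batch_total_tokens < available_size
    )
    within_prefill_budget = (
        extend_input_len + new_batch_input_tokens <= max_prefill_tokens
        or batch_is_empty
    )
    return fits_kv_cache and within_prefill_budget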
sglang/srt/managers/detokenizer_manager.py
@@ -11,7 +11,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -57,6 +57,8 @@ class DetokenizerManager:
             output_strs = []
             for i in range(len(recv_obj.rids)):
                 new_text = read_texts[i][len(surr_texts[i]) :]
+                if recv_obj.finished_reason[i] is None:
+                    new_text = find_printable_text(new_text)
                 output_strs.append(recv_obj.decoded_texts[i] + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
@@ -67,7 +69,7 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
-                    output_str=output_strs,
+                    output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
                 )
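The detokenizer change above routes still-running requests through sglang's find_printable_text before streaming, so text whose trailing bytes are not yet fully decoded is held back until a later step completes it. A simplified stand-in for the idea (this is not sglang's implementation; the helper below only illustrates the hold-back behavior using the '\ufffd' replacement character that incremental decoding typically emits for incomplete sequences):

def hold_back_incomplete(new_text: str) -> str:
    # If the tail of the freshly decoded chunk is a replacement character,
    # the underlying bytes are still incomplete; withhold it so the next
    # decoding step can finish the character before it is streamed out.
    return new_text.rstrip("\ufffd")

print(hold_back_incomplete("The answer is 4"))   # complete text streams as-is
print(hold_back_incomplete("caf\u00e9\ufffd"))   # trailing partial character withheld -> 'café'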
sglang/srt/managers/io_struct.py
@@ -122,7 +122,7 @@ class BatchTokenIDOut:
 @dataclass
 class BatchStrOut:
     rids: List[str]
-    output_str: List[str]
+    output_strs: List[str]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
sglang/srt/managers/tokenizer_manager.py
@@ -316,7 +316,7 @@ class TokenizerManager:
 
             recv_obj.meta_info[i]["id"] = rid
             out_dict = {
-                "text": recv_obj.output_str[i],
+                "text": recv_obj.output_strs[i],
                 "meta_info": recv_obj.meta_info[i],
             }
             state.out_list.append(out_dict)
@@ -333,17 +333,18 @@ class TokenizerManager:
             ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
-        if top_logprobs_num > 0:
-            ret["meta_info"][
-                "prefill_top_logprobs"
-            ] = self.detokenize_top_logprobs_tokens(
-                ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-            )
-            ret["meta_info"][
-                "decode_top_logprobs"
-            ] = self.detokenize_top_logprobs_tokens(
-                ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-            )
+
+            if top_logprobs_num > 0:
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
         return ret
 
     def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
@@ -383,7 +384,7 @@ def get_pixel_values(
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
-        if image_size != None:
+        if image_size is not None:
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):
sglang/srt/hf_transformers_utils.py
@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
         # `num_attention_heads` (among others). Assert here to fail early
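For reference, the early return added to get_hf_text_config above can be exercised with a stub config object. The sketch below mirrors only the dispatch logic and is not the full sglang function; the architecture string and stub attributes are illustrative assumptions:

from types import SimpleNamespace

def text_config_sketch(config):
    # Llava-style checkpoints loaded through sglang's non-HF integration keep
    # the top-level config; everything else falls through to text_config when
    # one is attached (the real function also asserts on its fields).
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        return config
    return getattr(config, "text_config", config)

llava = SimpleNamespace(architectures=["LlavaLlamaForCausalLM"], text_config=object())
llama = SimpleNamespace(architectures=["LlamaForCausalLM"])
assert text_config_sketch(llava) is llava   # multimodal: keep the outer config
assert text_config_sketch(llama) is llama   # pure text model: no-op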