sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama2.py

@@ -1,12 +1,18 @@
  # Adapted from
- # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
  """Inference-only LLaMA model compatible with HuggingFace weights."""
- from typing import Any, Dict, Optional, Tuple
+
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
+ import tqdm
  from torch import nn
  from transformers import LlamaConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
@@ -20,11 +26,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  class LlamaMLP(nn.Module):
@@ -71,6 +77,7 @@ class LlamaAttention(nn.Module):
          layer_id: int = 0,
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_is_neox_style: bool = True,
          max_position_embeddings: int = 8192,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
@@ -118,6 +125,7 @@ class LlamaAttention(nn.Module):
              max_position=max_position_embeddings,
              base=rope_theta,
              rope_scaling=rope_scaling,
+             is_neox_style=rope_is_neox_style,
          )
          self.attn = RadixAttention(
              self.num_heads,
@@ -152,6 +160,13 @@ class LlamaDecoderLayer(nn.Module):
          self.hidden_size = config.hidden_size
          rope_theta = getattr(config, "rope_theta", 10000)
          rope_scaling = getattr(config, "rope_scaling", None)
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
+         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
          max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
          self.self_attn = LlamaAttention(
              hidden_size=self.hidden_size,
@@ -160,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
              layer_id=layer_id,
              rope_theta=rope_theta,
              rope_scaling=rope_scaling,
+             rope_is_neox_style=rope_is_neox_style,
              max_position_embeddings=max_position_embeddings,
              quant_config=quant_config,
          )
@@ -250,6 +266,7 @@ class LlamaForCausalLM(nn.Module):
          self,
          config: LlamaConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -270,13 +287,7 @@ class LlamaForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -286,9 +297,9 @@ class LlamaForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         if get_tensor_model_parallel_rank() == 0:
+             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
sglang/srt/models/llama_classification.py (new file)

@@ -0,0 +1,104 @@
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ import tqdm
+ from torch import nn
+ from transformers import LlamaConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput
+ from sglang.srt.models.llama2 import LlamaModel
+
+
+ class LlamaForClassification(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = LlamaModel(config, quant_config=quant_config)
+
+         self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+         self.eos_token_id = config.eos_token_id
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         is_eos_token = input_ids == self.eos_token_id
+         hidden_states = hidden_states[is_eos_token]
+         scores = self.classification_head(hidden_states)
+
+         if scores.shape[0] != input_metadata.batch_size:
+             print("Warning: the EOS tokens are missing in some sentences.")
+             scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+
+         return LogitProcessorOutput(
+             next_token_logits=scores,
+             next_token_logprobs=scores,
+             normalized_prompt_logprobs=scores,
+             prefill_token_logprobs=torch.ones_like(input_ids),
+             prefill_top_logprobs=None,
+             decode_top_logprobs=None,
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         if get_tensor_model_parallel_rank() == 0:
+             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             if "lm_head" in name:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+ EntryClass = LlamaForClassification
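The new LlamaForClassification head scores a prompt by pooling the hidden state at each sequence's EOS token and projecting it with a linear head of width config.classification_out_size. A standalone sketch of that pooling step on dummy tensors (shapes and token ids are made up; in the model the hidden states come from LlamaModel and the ids from the batched prompts):

    import torch
    from torch import nn

    hidden_size, num_classes, eos_token_id = 16, 3, 2
    classification_head = nn.Linear(hidden_size, num_classes)

    # Two concatenated sequences, each ending in the EOS id (2).
    input_ids = torch.tensor([5, 7, 2, 9, 4, 2])
    hidden_states = torch.randn(input_ids.shape[0], hidden_size)

    eos_hidden = hidden_states[input_ids == eos_token_id]  # one vector per sequence
    scores = classification_head(eos_hidden)               # shape (2, num_classes)
    print(scores.shape)                                    # torch.Size([2, 3])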
sglang/srt/models/llava.py

@@ -1,23 +1,32 @@
  """Inference-only LLaVa model compatible with HuggingFace weights."""

- from typing import List, Optional
+ from typing import Iterable, List, Optional, Tuple

  import numpy as np
  import torch
  from torch import nn
- from transformers import CLIPVisionModel, LlavaConfig
+ from transformers import (
+     CLIPVisionConfig,
+     CLIPVisionModel,
+     LlavaConfig,
+     MistralConfig,
+     Qwen2Config,
+ )
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.managers.router.infer_batch import ForwardMode
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.infer_batch import ForwardMode
+ from sglang.srt.managers.controller.model_runner import InputMetadata
  from sglang.srt.mm_utils import (
      get_anyres_image_grid_shape,
      unpad_image,
      unpad_image_shape,
  )
  from sglang.srt.models.llama2 import LlamaForCausalLM
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+ from sglang.srt.models.mistral import MistralForCausalLM
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM


  class LlavaLlamaForCausalLM(nn.Module):
@@ -25,6 +34,7 @@ class LlavaLlamaForCausalLM(nn.Module):
          self,
          config: LlavaConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -233,13 +243,7 @@ class LlavaLlamaForCausalLM(nn.Module):
          elif input_metadata.forward_mode == ForwardMode.DECODE:
              return self.language_model(input_ids, positions, input_metadata)

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          # load clip vision model by cfg['mm_vision_tower']:
          # huggingface_name or path_of_clip_relative_to_llava_model_dir
          vision_path = self.config.mm_vision_tower
@@ -272,9 +276,8 @@ class LlavaLlamaForCausalLM(nn.Module):
              "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
          }
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         weights = list(weights)
+         for name, loaded_weight in weights:
              # FIXME: why projector weights read two times?
              if "projector" in name or "vision_tower" in name:
                  for weight_name, param_name in projector_weights.items():
@@ -285,9 +288,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                      weight_loader(param, loaded_weight)

          # load language model
-         self.language_model.load_weights(
-             model_name_or_path, cache_dir, load_format, revision
-         )
+         self.language_model.load_weights(weights)

          monkey_path_clip_vision_embed_forward()

@@ -296,6 +297,72 @@ class LlavaLlamaForCausalLM(nn.Module):
          return self.image_size // self.patch_size


+ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+     def __init__(
+         self,
+         config: LlavaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+         self.config = config
+         self.vision_tower = None
+         if getattr(self.config, "vision_config", None) is None:
+             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+         if getattr(self.config, "text_config", None) is None:
+             self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+         self.config.vision_config.hidden_size = config.mm_hidden_size
+         self.config.text_config.hidden_size = config.hidden_size
+
+         if getattr(self.config, "projector_hidden_act", None) is None:
+             self.config.projector_hidden_act = "gelu"
+
+         if getattr(self.config, "image_token_index", None) is None:
+             self.config.image_token_index = 151646
+
+         self.multi_modal_projector = LlavaMultiModalProjector(config)
+         self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+             self.language_model.model.image_newline = nn.Parameter(
+                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+             )
+
+
+ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+     def __init__(
+         self,
+         config: LlavaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+         self.config = config
+         self.vision_tower = None
+         if getattr(self.config, "vision_config", None) is None:
+             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+         if getattr(self.config, "text_config", None) is None:
+             self.config.text_config = MistralConfig(self.config._name_or_path)
+
+         self.config.vision_config.hidden_size = config.mm_hidden_size
+         self.config.text_config.hidden_size = config.hidden_size
+
+         if getattr(self.config, "projector_hidden_act", None) is None:
+             self.config.projector_hidden_act = "gelu"
+
+         if getattr(self.config, "image_token_index", None) is None:
+             self.config.image_token_index = 32000
+
+         self.multi_modal_projector = LlavaMultiModalProjector(config)
+         self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+             self.language_model.model.image_newline = nn.Parameter(
+                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+             )
+
+
  first_call = True


@@ -328,4 +395,4 @@ def monkey_path_clip_vision_embed_forward():
      )


- EntryClass = LlavaLlamaForCausalLM
+ EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
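The `weights = list(weights)` line added above exists because the incoming weight stream is consumed twice: once to pick out projector and vision-tower tensors, and again when it is handed to self.language_model.load_weights(weights). A generator would be exhausted after the first pass, so it is materialized first. Generic illustration (toy names, not sglang code):

    def weight_stream():
        yield ("vision_tower.embeddings.weight", 1)
        yield ("model.layers.0.self_attn.qkv_proj.weight", 2)

    gen = weight_stream()
    print([name for name, _ in gen])      # both entries
    print([name for name, _ in gen])      # []  -- generator already exhausted

    weights = list(weight_stream())
    print([name for name, _ in weights])  # first pass sees everything
    print([name for name, _ in weights])  # ...and so does the second pass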
sglang/srt/models/llavavid.py

@@ -1,24 +1,24 @@
  """Inference-only LLaVa video model compatible with HuggingFace weights."""

- import os
- from typing import List, Optional
+ from typing import Iterable, List, Optional, Tuple

  import numpy as np
  import torch
  from torch import nn
- from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
+ from transformers import CLIPVisionModel, LlavaConfig
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.managers.router.infer_batch import ForwardMode
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.infer_batch import ForwardMode
+ from sglang.srt.managers.controller.model_runner import InputMetadata
  from sglang.srt.mm_utils import (
      get_anyres_image_grid_shape,
      unpad_image,
      unpad_image_shape,
  )
  from sglang.srt.models.llama2 import LlamaForCausalLM
- from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


  class LlavaVidForCausalLM(nn.Module):
@@ -26,6 +26,7 @@ class LlavaVidForCausalLM(nn.Module):
          self,
          config: LlavaConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -65,7 +66,6 @@ class LlavaVidForCausalLM(nn.Module):
          pad_ids = pad_value * (
              (new_image_feature_len + len(pad_value)) // len(pad_value)
          )
-         # print(input_ids)
          offset = input_ids.index(self.config.image_token_index)
          # old_len + pad_len - 1, because we need to remove image_token_id
          new_input_ids = (
@@ -200,13 +200,7 @@ class LlavaVidForCausalLM(nn.Module):
          elif input_metadata.forward_mode == ForwardMode.DECODE:
              return self.language_model(input_ids, positions, input_metadata)

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          # load clip vision model by cfg['mm_vision_tower']:
          # huggingface_name or path_of_clip_relative_to_llava_model_dir
          vision_path = self.config.mm_vision_tower
@@ -244,9 +238,8 @@ class LlavaVidForCausalLM(nn.Module):
              "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
          }
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         weights = list(weights)
+         for name, loaded_weight in weights:
              # FIXME: why projector weights read two times?
              if "projector" in name or "vision_tower" in name:
                  for weight_name, param_name in projector_weights.items():
@@ -261,9 +254,7 @@ class LlavaVidForCausalLM(nn.Module):
                      weight_loader(param, loaded_weight)

          # load language model
-         self.language_model.load_weights(
-             model_name_or_path, cache_dir, load_format, revision
-         )
+         self.language_model.load_weights(weights)

          monkey_path_clip_vision_embed_forward()
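Both LLaVA variants delegate language-model loading to the LLaMA loader shown earlier, whose stacked_params_mapping routes separate q_proj/k_proj/v_proj (and gate_proj/up_proj) checkpoint tensors into fused parameters. A toy sketch of that name-routing step (the route helper is illustrative only; in the real loader the matched shard id is passed to the parameter's weight_loader):

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def route(name: str):
        # Map a checkpoint tensor name onto its fused parameter, if any.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in name:
                return name.replace(weight_name, param_name), shard_id
        return name, None  # not part of a fused parameter

    print(route("model.layers.0.self_attn.q_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', 'q')
    print(route("model.layers.0.mlp.down_proj.weight"))
    # ('model.layers.0.mlp.down_proj.weight', None)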