sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -4,13 +4,14 @@ import dataclasses
  import logging
  import multiprocessing as mp
  import os
- from typing import List
+ from typing import List, Dict

  import numpy as np
  import transformers
  import uvloop
  import zmq
  import zmq.asyncio
+ from fastapi import BackgroundTasks

  from sglang.srt.hf_transformers_utils import (
      get_config,
@@ -19,16 +20,18 @@ from sglang.srt.hf_transformers_utils import (
      get_tokenizer,
  )
  from sglang.srt.managers.io_struct import (
+     AbortReq,
      BatchStrOut,
-     DetokenizeReqInput,
      FlushCacheReq,
      GenerateReqInput,
      TokenizedGenerateReqInput,
  )
+ from sglang.srt.managers.io_struct import BatchTokenIDOut
  from sglang.srt.mm_utils import expand2square, process_anyres_image
  from sglang.srt.sampling_params import SamplingParams
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
+ from sglang.srt.utils import is_multimodal_model, load_image
+ from sglang.utils import get_exception_traceback

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -42,48 +45,12 @@ class ReqState:
      event: asyncio.Event


- global global_processor
-
-
- def init_global_processor(server_args: ServerArgs):
-     global global_processor
-     transformers.logging.set_verbosity_error()
-     global_processor = get_processor(
-         server_args.tokenizer_path,
-         tokenizer_mode=server_args.tokenizer_mode,
-         trust_remote_code=server_args.trust_remote_code,
-     )
-
-
- def get_pixel_values(
-     image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
- ):
-     try:
-         processor = processor or global_processor
-         image = load_image(image_data)
-         image_hash = hash(image_data)
-         if image_aspect_ratio == "pad":
-             image = expand2square(
-                 image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-             )
-             pixel_values = processor.image_processor(image)["pixel_values"][0]
-         elif image_aspect_ratio == "anyres":
-             pixel_values = process_anyres_image(
-                 image, processor.image_processor, image_grid_pinpoints
-             )
-         else:
-             pixel_values = processor.image_processor(image)["pixel_values"][0]
-         pixel_values = pixel_values.astype(np.float16)
-         return pixel_values, image_hash, image.size
-     except Exception:
-         print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
  class TokenizerManager:
      def __init__(
          self,
          server_args: ServerArgs,
          port_args: PortArgs,
+         model_overide_args: dict = None,
      ):
          self.server_args = server_args

@@ -96,9 +63,10 @@ class TokenizerManager:

          self.model_path = server_args.model_path
          self.hf_config = get_config(
-             self.model_path, trust_remote_code=server_args.trust_remote_code
+             self.model_path,
+             trust_remote_code=server_args.trust_remote_code,
+             model_overide_args=model_overide_args,
          )
-
          self.context_len = get_context_length(self.hf_config)

          if is_multimodal_model(self.model_path):
@@ -122,7 +90,7 @@
          )

          self.to_create_loop = True
-         self.rid_to_state = {}  # Dict[str -> ReqState]
+         self.rid_to_state: Dict[str, ReqState] = {}

      async def get_pixel_values(self, image_data):
          aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -143,15 +111,26 @@
                  image_data, aspect_ratio, grid_pinpoints, self.processor
              )

-     async def generate_request(self, obj: GenerateReqInput):
+     async def generate_request(self, obj: GenerateReqInput, request=None):
          if self.to_create_loop:
-             await self.create_handle_loop()
-
-         is_single = isinstance(obj.text, str)
+             self.create_handle_loop()

+         obj.post_init()
+         is_single = obj.is_single
          if is_single:
              rid = obj.rid
-             input_ids = self.tokenizer.encode(obj.text)
+
+             if obj.input_ids is None:
+                 input_ids = self.tokenizer.encode(obj.text)
+             else:
+                 input_ids = obj.input_ids
+
+             if len(input_ids) >= self.context_len:
+                 raise ValueError(
+                     f"The input ({len(input_ids)} tokens) is longer than the "
+                     f"model's context length ({self.context_len} tokens)."
+                 )
+
              sampling_params = SamplingParams(**obj.sampling_params)
              if sampling_params.max_new_tokens != 0:
                  sampling_params.normalize(self.tokenizer)
@@ -187,27 +166,54 @@
              self.rid_to_state[rid] = state

              while True:
-                 await event.wait()
-                 out = self.convert_logprob_style(state.out_list[-1],
-                                                  obj.return_logprob,
-                                                  obj.top_logprobs_num,
-                                                  obj.return_text_in_logprobs)
+                 try:
+                     await asyncio.wait_for(event.wait(), timeout=4)
+                 except asyncio.TimeoutError:
+                     if request is not None and await request.is_disconnected():
+                         self.abort_request(rid)
+                         raise ValueError(f"Abort request {rid}")
+                     continue
+
+                 out = self.convert_logprob_style(
+                     state.out_list[-1],
+                     obj.return_logprob,
+                     obj.top_logprobs_num,
+                     obj.return_text_in_logprobs,
+                 )

                  if self.server_args.log_requests and state.finished:
                      logger.info(f"in={obj.text}, out={out}")

-                 yield out
                  state.out_list = []
                  if state.finished:
                      del self.rid_to_state[rid]
+
+                     yield out
+
                      break
+
                  event.clear()
+
+                 yield out
          else:
-             assert obj.stream is False
-             bs = len(obj.text)
+             if obj.stream:
+                 raise ValueError("Do not support stream for batch mode.")
+
+             if obj.input_ids is None:
+                 bs = len(obj.text)
+             else:
+                 bs = len(obj.input_ids)
+
              for i in range(bs):
                  rid = obj.rid[i]
-                 input_ids = self.tokenizer.encode(obj.text[i])
+
+                 if obj.input_ids is None:
+                     input_text = obj.text[i]
+                     input_ids = self.tokenizer.encode(obj.text[i])
+                 else:
+                     input_text = None
+                     input_ids = obj.input_ids[i]
+
                  sampling_params = SamplingParams(**obj.sampling_params[i])
                  if sampling_params.max_new_tokens != 0:
                      sampling_params.normalize(self.tokenizer)
@@ -220,7 +226,7 @@
                  )
                  tokenized_obj = TokenizedGenerateReqInput(
                      rid=rid,
-                     input_text=obj.text[i],
+                     input_text=input_text,
                      input_ids=input_ids,
                      pixel_values=pixel_values,
                      image_hash=image_hash,
@@ -241,45 +247,84 @@
              for i in range(bs):
                  rid = obj.rid[i]
                  state = self.rid_to_state[rid]
-                 await state.event.wait()
+
+                 while True:
+                     try:
+                         await asyncio.wait_for(state.event.wait(), timeout=4)
+                         break
+                     except asyncio.TimeoutError:
+                         if request is not None and await request.is_disconnected():
+                             for rid in obj.rid:
+                                 self.abort_request(rid)
+                             raise ValueError(f"Abort request {rid}")
+                         continue
+
                  output_list.append(
-                     self.convert_logprob_style(state.out_list[-1],
-                                                obj.return_logprob[i],
-                                                obj.top_logprobs_num[i],
-                                                obj.return_text_in_logprobs))
+                     self.convert_logprob_style(
+                         state.out_list[-1],
+                         obj.return_logprob[i],
+                         obj.top_logprobs_num[i],
+                         obj.return_text_in_logprobs,
+                     )
+                 )
                  assert state.finished
                  del self.rid_to_state[rid]

              yield output_list

-     async def flush_cache(self):
-         flush_cache_req = FlushCacheReq()
-         self.send_to_router.send_pyobj(flush_cache_req)
+     def flush_cache(self):
+         req = FlushCacheReq()
+         self.send_to_router.send_pyobj(req)
+
+     def abort_request(self, rid):
+         if rid not in self.rid_to_state:
+             return
+         del self.rid_to_state[rid]
+         req = AbortReq(rid)
+         self.send_to_router.send_pyobj(req)
+
+     def create_abort_task(self, obj):
+         # Abort the request if the client is disconnected.
+         async def abort_request():
+             await asyncio.sleep(3)
+             if obj.is_single:
+                 self.abort_request(obj.rid)
+             else:
+                 for rid in obj.rids:
+                     self.abort_request(rid)

-     async def create_handle_loop(self):
+         background_tasks = BackgroundTasks()
+         background_tasks.add_task(abort_request)
+         return background_tasks
+
+     def create_handle_loop(self):
          self.to_create_loop = False
          loop = asyncio.get_event_loop()
          loop.create_task(self.handle_loop())

      async def handle_loop(self):
          while True:
-             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-
-             if isinstance(recv_obj, BatchStrOut):
-                 for i, rid in enumerate(recv_obj.rids):
-                     recv_obj.meta_info[i]["id"] = rid
-                     out_dict = {
-                         "text": recv_obj.output_str[i],
-                         "meta_info": recv_obj.meta_info[i],
-                     }
-                     state = self.rid_to_state[rid]
-                     state.out_list.append(out_dict)
-                     state.finished = recv_obj.finished[i]
-                     state.event.set()
-             else:
-                 raise ValueError(f"Invalid object: {recv_obj}")
-
-     def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+             recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+             assert isinstance(recv_obj, BatchStrOut)
+
+             for i, rid in enumerate(recv_obj.rids):
+                 state = self.rid_to_state.get(rid, None)
+                 if state is None:
+                     continue
+
+                 recv_obj.meta_info[i]["id"] = rid
+                 out_dict = {
+                     "text": recv_obj.output_str[i],
+                     "meta_info": recv_obj.meta_info[i],
+                 }
+                 state.out_list.append(out_dict)
+                 state.finished = recv_obj.finished_reason[i] is not None
+                 state.event.set()
+
+
+     def convert_logprob_style(
+         self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+     ):
          if return_logprob:
              ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                  ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -288,11 +333,15 @@
                  ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
              )
          if top_logprobs_num > 0:
-             ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                 ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+             ret["meta_info"]["prefill_top_logprobs"] = (
+                 self.detokenize_top_logprobs_tokens(
+                     ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                 )
              )
-             ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                 ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+             ret["meta_info"]["decode_top_logprobs"] = (
+                 self.detokenize_top_logprobs_tokens(
+                     ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                 )
              )
          return ret

@@ -312,3 +361,49 @@
          if t:
              top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
          return top_logprobs
+
+
+ global global_processor
+
+
+ def init_global_processor(server_args: ServerArgs):
+     global global_processor
+     transformers.logging.set_verbosity_error()
+     global_processor = get_processor(
+         server_args.tokenizer_path,
+         tokenizer_mode=server_args.tokenizer_mode,
+         trust_remote_code=server_args.trust_remote_code,
+     )
+
+
+ def get_pixel_values(
+     image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+ ):
+     try:
+         processor = processor or global_processor
+         image, image_size = load_image(image_data)
+         if image_size != None:
+             image_hash = hash(image_data)
+             pixel_values = processor.image_processor(image)["pixel_values"]
+             for _ in range(len(pixel_values)):
+                 pixel_values[_] = pixel_values[_].astype(np.float16)
+             pixel_values = np.stack(pixel_values, axis=0)
+             return pixel_values, image_hash, image_size
+         else:
+             image_hash = hash(image_data)
+             if image_aspect_ratio == "pad":
+                 image = expand2square(
+                     image,
+                     tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                 )
+                 pixel_values = processor.image_processor(image)["pixel_values"][0]
+             elif image_aspect_ratio == "anyres":
+                 pixel_values = process_anyres_image(
+                     image, processor.image_processor, image_grid_pinpoints
+                 )
+             else:
+                 pixel_values = processor.image_processor(image)["pixel_values"][0]
+             pixel_values = pixel_values.astype(np.float16)
+             return pixel_values, image_hash, image.size
+     except Exception:
+         print("Exception in TokenizerManager:\n" + get_exception_traceback())
sglang/srt/model_config.py CHANGED
@@ -10,11 +10,14 @@ class ModelConfig:
          trust_remote_code: bool = True,
          revision: Optional[str] = None,
          context_length: Optional[int] = None,
+         model_overide_args: Optional[dict] = None,
      ) -> None:
          self.path = path
          self.trust_remote_code = trust_remote_code
          self.revision = revision
-         self.hf_config = get_config(self.path, trust_remote_code, revision)
+         self.model_overide_args = model_overide_args
+         self.hf_config = get_config(self.path, trust_remote_code, revision,
+                                     model_overide_args=model_overide_args)

          if context_length is not None:
              self.context_len = context_length
@@ -40,4 +43,4 @@ class ModelConfig:
              self.num_key_value_heads = self.num_attention_heads
          self.hidden_size = self.hf_config.hidden_size
          self.num_hidden_layers = self.hf_config.num_hidden_layers
-         self.vocab_size = self.hf_config.vocab_size
+         self.vocab_size = self.hf_config.vocab_size
sglang/srt/models/commandr.py CHANGED
@@ -18,38 +18,38 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
+
  # This file is based on the LLama model definition file in transformers
  """PyTorch Cohere model."""
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Iterable

  import torch
  import torch.utils.checkpoint
  from torch import nn
  from torch.nn.parameter import Parameter
  from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import (
-     QuantizationConfig)
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
- from vllm.distributed import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
- )
  from vllm.model_executor.utils import set_weight_attrs
- from sglang.srt.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  @torch.compile
@@ -305,6 +305,7 @@ class CohereForCausalLM(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -328,13 +329,7 @@
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -345,9 +340,7 @@
          ]
          params_dict = dict(self.named_parameters())
          loaded_params = set()
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              for param_name, shard_name, shard_id in stacked_params_mapping:
                  if shard_name not in name:
                      continue
sglang/srt/models/dbrx.py CHANGED
@@ -1,41 +1,36 @@
  # Adapted from:
- # https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
  # coding=utf-8
- from typing import Optional
+ from typing import Iterable, Optional, Tuple

  import torch
  import torch.nn as nn
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
  from vllm.model_executor.layers.fused_moe import fused_moe
  from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import (
-     QuantizationConfig)
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      DEFAULT_VOCAB_PADDING_SIZE,
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.distributed import (
-     tensor_model_parallel_all_reduce,
- )
- from vllm.distributed import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
- )
  from vllm.model_executor.utils import set_weight_attrs
- from sglang.srt.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
- from sglang.srt.models.dbrx_config import DbrxConfig
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  class DbrxRouter(nn.Module):
@@ -291,7 +286,9 @@ class DbrxBlock(nn.Module):
          quant_config: Optional[QuantizationConfig] = None,
      ):
          super().__init__()
-         self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
+         self.norm_attn_norm = DbrxFusedNormAttention(
+             config, layer_id, quant_config=quant_config
+         )
          self.ffn = DbrxExperts(config, quant_config=quant_config)

      def forward(
@@ -322,7 +319,10 @@ class DbrxModel(nn.Module):
              config.d_model,
          )
          self.blocks = nn.ModuleList(
-             [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
+             [
+                 DbrxBlock(config, i, quant_config=quant_config)
+                 for i in range(config.n_layers)
+             ]
          )
          self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
          for module in self.modules():
@@ -353,6 +353,7 @@ class DbrxForCausalLM(nn.Module):
          self,
          config: DbrxConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ):
          super().__init__()
          self.config = config
@@ -378,13 +379,7 @@
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          expert_params_mapping = [
              (
                  "ws" if weight_name in ["w1", "v1"] else "w2s",
@@ -393,9 +388,7 @@
              for weight_name in ["w1", "v1", "w2"]
          ]
          params_dict = dict(self.named_parameters(remove_duplicate=False))
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         for name, loaded_weight in weights:
              for param_name, weight_name in expert_params_mapping:
                  if weight_name not in name:
                      continue