sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -4,13 +4,14 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import List, Dict
 
 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks
 
 from sglang.srt.hf_transformers_utils import (
     get_config,
@@ -19,16 +20,18 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -42,48 +45,12 @@ class ReqState:
     event: asyncio.Event
 
 
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image = load_image(image_data)
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
-                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-            )
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        elif image_aspect_ratio == "anyres":
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
-        else:
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -96,9 +63,10 @@ class TokenizerManager:
 
         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path,
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
@@ -122,7 +90,7 @@ class TokenizerManager:
         )
 
         self.to_create_loop = True
-        self.rid_to_state = {}
+        self.rid_to_state: Dict[str, ReqState] = {}
 
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -143,15 +111,26 @@ class TokenizerManager:
                 image_data, aspect_ratio, grid_pinpoints, self.processor
             )
 
-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-
-
-        is_single = isinstance(obj.text, str)
+            self.create_handle_loop()
 
+        obj.post_init()
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -187,27 +166,54 @@ class TokenizerManager:
             self.rid_to_state[rid] = state
 
             while True:
-
-
-
-
-
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
 
                 if self.server_args.log_requests and state.finished:
                     logger.info(f"in={obj.text}, out={out}")
 
-                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-
-
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -220,7 +226,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=obj.text[i],
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
@@ -241,45 +247,84 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
-                    self.convert_logprob_style(
-
-
-
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]
 
             yield output_list
 
-
-
-        self.send_to_router.send_pyobj(
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)
 
-
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks
+
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
 
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_str[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -288,11 +333,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                    )
                 )
-                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                 )
             return ret
 
@@ -312,3 +361,49 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
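
The tokenizer manager changes above revolve around client-disconnect handling: generate_request now takes the incoming HTTP request and polls the per-request event with a 4-second timeout so it can notice a dropped connection and abort, and the new create_abort_task returns a FastAPI BackgroundTasks object that schedules a delayed abort after a response is handed back. A minimal sketch of how an HTTP handler could use this API is below; the endpoint name, the app object, and the module-level tokenizer_manager instance are assumptions for illustration (the real wiring lives in sglang/srt/server.py).

    # Illustrative sketch only; not taken verbatim from sglang/srt/server.py.
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    from sglang.srt.managers.io_struct import GenerateReqInput

    app = FastAPI()
    # tokenizer_manager is assumed to be a TokenizerManager built at server startup.

    @app.post("/generate")
    async def generate(raw: dict, request: Request):
        obj = GenerateReqInput(**raw)  # assumes the body carries GenerateReqInput fields
        ret = None
        try:
            # Passing `request` lets generate_request call request.is_disconnected()
            # whenever its 4-second wait on the result event times out.
            async for out in tokenizer_manager.generate_request(obj, request):
                ret = out  # non-streaming: keep the final yielded output
        except ValueError as e:
            # Raised after abort_request() once the client has gone away.
            return JSONResponse({"error": str(e)}, status_code=400)
        # The background task aborts the request a few seconds later if the client
        # drops right after the response is returned (mainly useful for streaming).
        return JSONResponse(ret, background=tokenizer_manager.create_abort_task(obj))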
sglang/srt/model_config.py
CHANGED
@@ -10,11 +10,14 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.hf_config = get_config(self.path, trust_remote_code, revision)
+        self.model_overide_args = model_overide_args
+        self.hf_config = get_config(self.path, trust_remote_code, revision,
+                                     model_overide_args=model_overide_args)
 
         if context_length is not None:
             self.context_len = context_length
@@ -40,4 +43,4 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
+        self.vocab_size = self.hf_config.vocab_size
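
The model_config.py hunks thread the same optional model_overide_args dict that TokenizerManager now accepts into get_config, so individual HuggingFace config fields can be patched before the model is built (the merge itself happens inside sglang/srt/hf_transformers_utils.get_config). A small usage sketch follows; the model path and the overridden field are hypothetical examples, not values taken from the diff.

    from sglang.srt.model_config import ModelConfig

    # Hypothetical: shrink the model for a quick smoke test by overriding one config field.
    overrides = {"num_hidden_layers": 2}
    cfg = ModelConfig(
        "meta-llama/Llama-2-7b-hf",   # hypothetical model path
        trust_remote_code=True,
        model_overide_args=overrides,
    )
    # Derived attributes are read from hf_config, so they reflect the override
    # if get_config applies it before ModelConfig inspects the config.
    print(cfg.num_hidden_layers, cfg.context_len, cfg.vocab_size)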
sglang/srt/models/commandr.py
CHANGED
@@ -18,38 +18,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
+
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 @torch.compile
@@ -305,6 +305,7 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -328,13 +329,7 @@ class CohereForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -345,9 +340,7 @@ class CohereForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
sglang/srt/models/dbrx.py
CHANGED
@@ -1,41 +1,36 @@
 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
 # coding=utf-8
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.models.dbrx_config import DbrxConfig
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class DbrxRouter(nn.Module):
@@ -291,7 +286,9 @@ class DbrxBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
         self.ffn = DbrxExperts(config, quant_config=quant_config)
 
     def forward(
@@ -322,7 +319,10 @@ class DbrxModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
         )
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():
@@ -353,6 +353,7 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -378,13 +379,7 @@ class DbrxForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
             (
                 "ws" if weight_name in ["w1", "v1"] else "w2s",
@@ -393,9 +388,7 @@ class DbrxForCausalLM(nn.Module):
             for weight_name in ["w1", "v1", "w2"]
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
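
The commandr.py and dbrx.py hunks show the same two interface changes: the constructors gain an optional cache_config parameter, and load_weights no longer takes a checkpoint path that it iterates itself via hf_model_weights_iterator (from the removed sglang/srt/weight_utils.py) but instead consumes an iterable of (name, tensor) pairs, matching vLLM's model-loader convention with default_weight_loader imported from vllm.model_executor.model_loader.weight_utils. A toy sketch of the new calling convention follows; the weight names, shapes, and the commented-out construction are illustrative only.

    import torch

    def toy_weight_stream():
        # Hypothetical (name, tensor) pairs; a real loader yields them from checkpoint files.
        yield "model.embed_tokens.weight", torch.empty(32000, 4096)
        yield "model.layers.0.self_attn.q_proj.weight", torch.empty(4096, 4096)

    # model = CohereForCausalLM(config, quant_config=None, cache_config=None)
    # model.load_weights(toy_weight_stream())  # consumes the pairs instead of reading a path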