sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py

@@ -1,13 +1,17 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
+import inspect

 import uvloop
 import zmq
 import zmq.asyncio

 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -33,51 +37,41 @@ class DetokenizerManager:

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
-
-            if isinstance(recv_obj, BatchTokenIDOut):
-                output_tokens = recv_obj.output_tokens
-
-                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-                output_strs = self.tokenizer.batch_decode(
-                    output_tokens,
-                    skip_special_tokens=recv_obj.skip_special_tokens[0],
-                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                        0
-                    ],
-                )
-
-                # Trim stop str
-                # TODO(lmzheng): handle the case where multiple stop strs are hit
-                for i in range(len(output_strs)):
-                    if recv_obj.hit_stop_str[i] is not None:
-                        pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                        if pos != -1:
-                            output_strs[i] = output_strs[i][:pos]
-
-                    if len(output_tokens[i]) > 0:
-                        first_token = self.tokenizer.convert_ids_to_tokens(
-                            int(output_tokens[i][0])
-                        )
-                        if not isinstance(first_token, str):
-                            first_token = first_token.decode("utf-8", errors="ignore")
-                        if first_token.startswith("▁"):
-                            output_strs[i] = " " + output_strs[i]
-
-                    output_strs[i] = (
-                        recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                    )
-
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
-                    )
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)
+
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_str=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
                 )
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+            )


 def start_detokenizer_process(
@@ -85,9 +79,11 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception as e:
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
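
Note: the rewritten handle_loop above switches to incremental detokenization. For each request it decodes a short "surrogate" window (surr_output_ids) and a longer "read" window (read_output_ids), then keeps only the suffix that the longer decode adds; decoding with a little left context keeps leading spaces and multi-token characters intact without re-decoding the whole sequence. A minimal standalone sketch of the idea, assuming only a HuggingFace-style tokenizer.decode(); the helper name is illustrative, not sglang API:

    def incremental_decode(tokenizer, surr_ids, read_ids, decoded_so_far):
        # Decode the short surrogate window and the longer read window.
        surr_text = tokenizer.decode(surr_ids, skip_special_tokens=True)
        read_text = tokenizer.decode(read_ids, skip_special_tokens=True)
        # The new text is whatever the read window adds beyond the surrogate text.
        new_text = read_text[len(surr_text):]
        return decoded_so_far + new_text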
sglang/srt/managers/io_struct.py

@@ -1,7 +1,13 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

+from sglang.srt.managers.controller.infer_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams


@@ -27,14 +33,12 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

     def post_init(self):
-
-        if self.text is None:
-            assert self.input_ids is not None, "Either text or input_ids should be provided"
-        else:
-            assert self.input_ids is None, "Either text or input_ids should be provided"
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
             is_single = isinstance(self.text, str)
@@ -69,7 +73,8 @@ class GenerateReqInput:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                assert isinstance(self.rid, list)
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")

             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -105,13 +110,13 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
-    hit_stop_str: List[Optional[str]]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
-    finished: List[bool]
+    finished_reason: List[BaseFinishReason]


 @dataclass
@@ -119,7 +124,7 @@ class BatchStrOut:
     rids: List[str]
     output_str: List[str]
     meta_info: List[Dict]
-    finished: List[bool]
+    finished_reason: List[BaseFinishReason]


 @dataclass
@@ -127,6 +132,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
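
Note: BatchTokenIDOut and BatchStrOut now carry a structured finished_reason per request instead of finished: List[bool], so downstream consumers can see why a request stopped (for example, which stop string was matched). The real BaseFinishReason / FINISH_MATCHED_STR classes live in sglang/srt/managers/controller/infer_batch.py and are not part of this diff; the sketch below only illustrates the pattern with made-up definitions:

    # Illustrative sketch of the finish-reason pattern (not the sglang definitions).
    from dataclasses import dataclass

    class BaseFinishReason:
        """Marker base class; a request is finished iff its reason is not None."""

    @dataclass
    class FINISH_MATCHED_STR(BaseFinishReason):
        matched: str  # the stop string that ended generation

    reason = FINISH_MATCHED_STR(matched="</s>")
    if isinstance(reason, FINISH_MATCHED_STR):
        print("trim output at", reason.matched)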
sglang/srt/managers/tokenizer_manager.py

@@ -1,16 +1,19 @@
+"""TokenizerManager is a process that tokenizes the text."""
+
 import asyncio
 import concurrent.futures
 import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import Dict, List

 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks

 from sglang.srt.hf_transformers_utils import (
     get_config,
@@ -19,8 +22,9 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-    DetokenizeReqInput,
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -28,7 +32,8 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -42,51 +47,6 @@ class ReqState:
     event: asyncio.Event


-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -132,7 +92,7 @@ class TokenizerManager:
             )

         self.to_create_loop = True
-        self.rid_to_state = {}  # Dict[str -> ReqState]
+        self.rid_to_state: Dict[str, ReqState] = {}

     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -153,10 +113,11 @@ class TokenizerManager:
                 image_data, aspect_ratio, grid_pinpoints, self.processor
             )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
+            self.create_handle_loop()

+        obj.post_init()
         is_single = obj.is_single
         if is_single:
             rid = obj.rid
@@ -169,7 +130,7 @@
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )

             sampling_params = SamplingParams(**obj.sampling_params)
@@ -207,23 +168,38 @@ class TokenizerManager:
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
-                out = self.convert_logprob_style(state.out_list[-1],
-                                                 obj.return_logprob,
-                                                 obj.top_logprobs_num,
-                                                 obj.return_text_in_logprobs)
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )

                 if self.server_args.log_requests and state.finished:
                     logger.info(f"in={obj.text}, out={out}")

-                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-            assert obj.stream is False
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

             if obj.input_ids is None:
                 bs = len(obj.text)
@@ -273,45 +249,83 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
-                    self.convert_logprob_style(state.out_list[-1],
-                                               obj.return_logprob[i],
-                                               obj.top_logprobs_num[i],
-                                               obj.return_text_in_logprobs))
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]

             yield output_list

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj: GenerateReqInput):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)

-    async def create_handle_loop(self):
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks
+
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, BatchStrOut):
-                for i, rid in enumerate(recv_obj.rids):
-                    recv_obj.meta_info[i]["id"] = rid
-                    out_dict = {
-                        "text": recv_obj.output_str[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                    state = self.rid_to_state[rid]
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished[i]
-                    state.event.set()
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
-
-    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_str[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
@@ -320,10 +334,14 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
                     ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
                 )
-                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
                     ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
                 )
         return ret
@@ -344,3 +362,49 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
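
Note: generate_request no longer blocks forever on event.wait(); it wakes up every few seconds, checks whether the HTTP client has disconnected, and if so aborts the request via AbortReq. A minimal standalone sketch of that wait-with-disconnect-polling pattern (the 4-second timeout mirrors the code above; is_disconnected stands in for FastAPI's Request.is_disconnected, and the helper names are illustrative):

    import asyncio

    async def wait_or_abort(event: asyncio.Event, is_disconnected, abort, timeout=4):
        """Wait for `event`, polling for client disconnects every `timeout` seconds."""
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=timeout)
                return  # result is ready
            except asyncio.TimeoutError:
                if await is_disconnected():
                    abort()  # e.g. send AbortReq(rid) to the controller
                    raise ValueError("Request aborted: client disconnected")
                # still connected; keep waiting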
sglang/srt/model_config.py

@@ -1,5 +1,7 @@
 from typing import Optional

+from transformers import PretrainedConfig
+
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


@@ -15,11 +17,14 @@ class ModelConfig:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.hf_config = get_config(self.path, trust_remote_code, revision)
-
-        if model_overide_args is not None:
-            self.hf_config.update(model_overide_args)
-
+        self.model_overide_args = model_overide_args
+        self.hf_config = get_config(
+            self.path,
+            trust_remote_code,
+            revision,
+            model_overide_args=model_overide_args,
+        )
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -45,3 +50,76 @@ class ModelConfig:
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
         self.vocab_size = self.hf_config.vocab_size
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False)
+        )
+        if not new_decoder_arch_falcon and getattr(
+            self.hf_text_config, "multi_query", False
+        ):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
+                self.hf_config.num_attention_heads,
+            )
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
+    def get_num_kv_heads(self, tensor_parallel_size) -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1, total_num_kv_heads // tensor_parallel_size)
+
+
+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+    else:
+        return config
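
Note: the per-GPU KV head count added above follows the vLLM rule max(1, total_kv_heads // tp_size): grouped-query models keep at least one KV head per tensor-parallel rank and replicate heads when tp_size exceeds the head count. A quick worked example (the head counts below are illustrative values, not read from any config):

    # Worked example of get_num_kv_heads()-style arithmetic (illustrative values).
    def num_kv_heads_per_gpu(total_num_kv_heads: int, tp_size: int) -> int:
        return max(1, total_num_kv_heads // tp_size)

    print(num_kv_heads_per_gpu(8, 4))   # 2 -> 8 GQA heads split across 4 GPUs
    print(num_kv_heads_per_gpu(8, 16))  # 1 -> fewer heads than ranks: replicate
    print(num_kv_heads_per_gpu(32, 8))  # 4 -> MHA-style model, plain division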