sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/scheduler.py CHANGED
@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))

             node_to_weight = defaultdict(int)
-            self._calc_weight_recursive(
-                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)

-            tmp_queue = []
-            self._get_weight_priority_recursive(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(tmp_queue) == len(forward_queue)
-            return tmp_queue
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

-    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
+            self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def _get_weight_priority_recursive(
-        self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
-    ):
-        visit_list = [child for child in cur_node.children.values()]
-        visit_list.sort(key=lambda x: -node_to_wight[x])
-        # for node in visit_list:
-        #     print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
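For context, the new "dfs-weight" heuristic can be read as: count how many queued requests end at each cache-tree node, propagate those counts up to ancestors, then emit requests in a depth-first order that visits heavier subtrees first, so requests sharing a cached prefix stay adjacent. Below is a minimal standalone sketch of that idea; the TreeNode class and the toy requests are hypothetical stand-ins, not part of sglang.

from collections import defaultdict

# Hypothetical tree node; the real scheduler walks the radix-cache tree instead.
class TreeNode:
    def __init__(self, name):
        self.name = name
        self.children = {}

def calc_weight(node, node_to_weight):
    # Propagate each subtree's request count up to its ancestors.
    for child in node.children.values():
        calc_weight(child, node_to_weight)
        node_to_weight[node] += node_to_weight[child]

def get_dfs_priority(node, node_to_weight, node_to_reqs, out):
    # Visit heavier subtrees first so requests sharing a prefix stay adjacent.
    for child in sorted(node.children.values(), key=lambda c: -node_to_weight[c]):
        get_dfs_priority(child, node_to_weight, node_to_reqs, out)
    out.extend(node_to_reqs[node])

root = TreeNode("root")
a, b = TreeNode("a"), TreeNode("b")
root.children = {"a": a, "b": b}
node_to_reqs = {root: [], a: ["req1"], b: ["req2", "req3"]}
node_to_weight = defaultdict(int, {n: len(r) for n, r in node_to_reqs.items()})
calc_weight(root, node_to_weight)
order = []
get_dfs_priority(root, node_to_weight, node_to_reqs, order)
print(order)  # ['req2', 'req3', 'req1'] -- b's subtree is heavier, so it is served first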
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import concurrent.futures
 import dataclasses
+import logging
 import multiprocessing as mp
 import os
 from typing import List
@@ -10,6 +11,7 @@ import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -30,13 +32,14 @@ from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
-    lock: asyncio.Lock


 global global_processor
@@ -57,21 +60,29 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
-                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-            )
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        elif image_aspect_ratio == "anyres":
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
         else:
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash, image.size
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())

@@ -81,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args

@@ -93,9 +105,10 @@ class TokenizerManager:

         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path, trust_remote_code=server_args.trust_remote_code
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
@@ -144,11 +157,21 @@ class TokenizerManager:
         if self.to_create_loop:
             await self.create_handle_loop()

-        is_single = isinstance(obj.text, str)
-
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-            input_ids = self.tokenizer.encode(obj.text)
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)"
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
@@ -174,18 +197,26 @@ class TokenizerManager:
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,
+                top_logprobs_num=obj.top_logprobs_num,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

-            lock = asyncio.Lock()
             event = asyncio.Event()
-            state = ReqState([], False, event, lock)
+            state = ReqState([], False, event)
             self.rid_to_state[rid] = state

            while True:
                 await event.wait()
-                yield state.out_list[-1]
+                out = self.convert_logprob_style(state.out_list[-1],
+                                                 obj.return_logprob,
+                                                 obj.top_logprobs_num,
+                                                 obj.return_text_in_logprobs)
+
+                if self.server_args.log_requests and state.finished:
+                    logger.info(f"in={obj.text}, out={out}")
+
+                yield out
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
@@ -193,10 +224,22 @@ class TokenizerManager:
                 event.clear()
         else:
             assert obj.stream is False
-            bs = len(obj.text)
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-                input_ids = self.tokenizer.encode(obj.text[i])
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
@@ -209,7 +252,7 @@ class TokenizerManager:
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=obj.text[i],
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
@@ -217,13 +260,13 @@ class TokenizerManager:
                     sampling_params=sampling_params,
                     return_logprob=obj.return_logprob[i],
                     logprob_start_len=obj.logprob_start_len[i],
+                    top_logprobs_num=obj.top_logprobs_num[i],
                     stream=obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)

-                lock = asyncio.Lock()
                 event = asyncio.Event()
-                state = ReqState([], False, event, lock)
+                state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

             output_list = []
@@ -231,16 +274,16 @@ class TokenizerManager:
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
                 await state.event.wait()
-                output_list.append(state.out_list[-1])
+                output_list.append(
+                    self.convert_logprob_style(state.out_list[-1],
+                                               obj.return_logprob[i],
+                                               obj.top_logprobs_num[i],
+                                               obj.return_text_in_logprobs))
                 assert state.finished
                 del self.rid_to_state[rid]

             yield output_list

-    async def detokenize(self, obj: DetokenizeReqInput):
-        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
-        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
-
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
@@ -267,3 +310,37 @@ class TokenizerManager:
                     state.event.set()
             else:
                 raise ValueError(f"Invalid object: {recv_obj}")
+
+    def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+        if return_logprob:
+            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            )
+            if top_logprobs_num > 0:
+                ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
+        return ret
+
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+        if not decode_to_text:
+            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+        token_ids = [tid for _, tid in token_logprobs]
+        token_texts = self.tokenizer.batch_decode(token_ids)
+        return [
+            (logprob, token_id, token_text)
+            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+        ]
+
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+        for i, t in enumerate(top_logprobs):
+            if t:
+                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        return top_logprobs
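For orientation, the new logprob plumbing turns each (logprob, token_id) pair into a (logprob, token_id, text) triple when the caller asks for text. A rough standalone sketch of that conversion follows; fake_batch_decode is a hypothetical stand-in for tokenizer.batch_decode, which maps token ids to strings.

# Stand-in illustration of the detokenize_logprob_tokens conversion shown above.
def fake_batch_decode(token_ids):
    # placeholder for tokenizer.batch_decode
    return [f"<tok{tid}>" for tid in token_ids]

def detokenize_logprob_tokens(token_logprobs, decode_to_text, batch_decode=fake_batch_decode):
    if not decode_to_text:
        # keep the id, leave the text slot empty
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
    token_texts = batch_decode([tid for _, tid in token_logprobs])
    return [
        (logprob, token_id, text)
        for (logprob, token_id), text in zip(token_logprobs, token_texts)
    ]

print(detokenize_logprob_tokens([(-0.1, 42), (-2.3, 7)], decode_to_text=True))
# [(-0.1, 42, '<tok42>'), (-2.3, 7, '<tok7>')]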
sglang/srt/memory_pool.py CHANGED
@@ -31,9 +31,6 @@ class ReqToTokenPool:
         self.can_use_mem_size += free_index.shape[0]
         self.mem_state[free_index] = 1

-        # if self.can_use_mem_size == len(self.mem_state):
-        #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
     def clear(self):
         self.mem_state.fill_(1)
         self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.alloc_ct = 0
+        self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -83,9 +80,6 @@
         self.add_refs(select_index)
         return select_index.to(torch.int32), start_loc, start_loc + need_size

-    def free(self, free_index):
-        return self.decrease_refs(free_index)
-
     def used_size(self):
         return len(torch.nonzero(self.mem_state).squeeze(1))

@@ -93,20 +87,17 @@
         return torch.sum(self.mem_state == 0).item()

     def add_refs(self, token_index: torch.Tensor):
-        self.alloc_ct += len(token_index)
+        self.total_ref_ct += len(token_index)
         self.mem_state[token_index] += 1

-    def decrease_refs(self, token_index: torch.Tensor):
-        self.alloc_ct -= len(token_index)
+    def dec_refs(self, token_index: torch.Tensor):
+        self.total_ref_ct -= len(token_index)
         self.mem_state[token_index] -= 1

         num_freed = torch.sum(self.mem_state[token_index] == 0)

-        # if self.alloc_ct == 0:
-        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
-
         return num_freed

     def clear(self):
         self.mem_state.fill_(0)
-        self.alloc_ct = 0
+        self.total_ref_ct = 0
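The renames above make it explicit that TokenToKVPool does per-slot reference counting: add_refs bumps a counter for each KV slot and dec_refs reports how many slots just dropped to zero and became reusable. Below is a simplified NumPy stand-in of that bookkeeping (the real pool keeps these counters in a torch.int16 tensor on the GPU); the ToyTokenPool class is a hypothetical illustration, not sglang code.

import numpy as np

class ToyTokenPool:
    def __init__(self, size):
        self.mem_state = np.zeros(size, dtype=np.int16)  # per-slot reference count
        self.total_ref_ct = 0

    def add_refs(self, token_index):
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1

    def dec_refs(self, token_index):
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1
        # number of slots whose count just reached zero, i.e. became reusable
        return int(np.sum(self.mem_state[token_index] == 0))

pool = ToyTokenPool(8)
pool.add_refs(np.array([0, 1, 2]))
pool.add_refs(np.array([2, 3]))            # slot 2 is now shared by two requests
print(pool.dec_refs(np.array([0, 1, 2])))  # 2 -> slots 0 and 1 freed; slot 2 still referenced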
sglang/srt/model_config.py CHANGED
@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)

+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -29,6 +33,13 @@ class ModelConfig:
         )
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+
+        # for Dbrx and MPT models
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            self.num_key_value_heads = getattr(
+                self.hf_config.attn_config, "kv_n_heads", None
+            )
+
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
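The new model_overide_args parameter lets a caller overwrite fields of the loaded HuggingFace config before it is consumed; the mechanism is just a dict-to-attribute update. A minimal sketch with a hypothetical stand-in config class follows (the override keys shown are illustrative examples, not actual sglang defaults).

# Stand-in for the hf_config.update(model_overide_args) call shown above;
# transformers.PretrainedConfig.update behaves similarly (sets attributes from a dict).
class ToyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def update(self, overrides):
        for key, value in overrides.items():
            setattr(self, key, value)

hf_config = ToyConfig(model_type="llava", num_attention_heads=32)
model_overide_args = {"model_type": "llavavid", "num_frames": 16}  # hypothetical overrides
if model_overide_args is not None:
    hf_config.update(model_overide_args)
print(hf_config.model_type, hf_config.num_frames)  # llavavid 16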