sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
  import json
+ import logging
  import re
  from abc import ABC, abstractmethod
  from json import JSONDecodeError, JSONDecoder
@@ -8,6 +9,8 @@ import partial_json_parser
  from partial_json_parser.core.options import Allow
  from pydantic import BaseModel, Field

+ logger = logging.getLogger(__name__)
+
  TOOLS_TAG_LIST = [
      "<|plugin|>",
      "<function=",
@@ -88,17 +91,43 @@ class BaseFormatDetector:
          self.bot_token = ""
          self.eot_token = ""

-     def parse_base_json(self, action: Dict, tools: List[Function]):
-         name, parameters = action["name"], json.dumps(
-             action.get("parameters", action.get("arguments", {})),
-             ensure_ascii=False,
-         )
-         tool_index = [tool.function.name for tool in tools].index(name)
-         tool_call_item = ToolCallItem(
-             tool_index=tool_index, name=name, parameters=parameters
-         )
-         calls = [tool_call_item]
-         return calls
+     def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+         tool_indices = {
+             tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
+         }
+         if not isinstance(action, list):
+             name = action.get("name")
+             if not name or name not in tool_indices:
+                 logger.warning(f"Model attempted to call undefined function: {name}")
+                 return []
+
+             return [
+                 ToolCallItem(
+                     tool_index=tool_indices[name],
+                     name=name,
+                     parameters=json.dumps(
+                         action.get("parameters") or action.get("arguments", {}),
+                         ensure_ascii=False,
+                     ),
+                 )
+             ]
+
+         results = []
+         for act in action:
+             name = act.get("name")
+             if name and name in tool_indices:
+                 results.append(
+                     ToolCallItem(
+                         tool_index=tool_indices[name],
+                         name=name,
+                         parameters=json.dumps(
+                             act.get("parameters") or act.get("arguments", {}),
+                             ensure_ascii=False,
+                         ),
+                     )
+                 )
+
+         return results

      def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
          """
@@ -112,9 +141,7 @@ class BaseFormatDetector:
          self, new_text: str, tools: List[Function]
      ) -> StreamingParseResult:
          """
-         Streaming incremental parsing, referencing the logic of Llama32Detector.
-         We partially parse JSON within <tool_call>...</tool_call>, and handle
-         incremental argument output.
+         Streaming incremental parsing with tool validation.
          """
          # Append new text to buffer
          self._buffer += new_text
@@ -125,17 +152,19 @@ class BaseFormatDetector:
              new_text = new_text.replace(self.eot_token, "")
              return StreamingParseResult(normal_text=new_text)

-         # bit mask flags for partial JSON parsing. If the name hasn't been
-         # sent yet, don't allow sending
-         # an incomplete string since OpenAI only ever (as far as I have
-         # seen) allows sending the entire tool/ function name at once.
+         # Build tool indices if not already built
+         if not hasattr(self, "_tool_indices"):
+             self._tool_indices = {
+                 tool.function.name: i
+                 for i, tool in enumerate(tools)
+                 if tool.function and tool.function.name
+             }
+
          flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
          try:
              tool_call_arr = []
              is_complete = []
              try:
-                 # depending on the prompt format the Llama model may or may not
-                 # prefix the output with the <|python_tag|> token
                  start_idx = (
                      len(self.bot_token)
                      if current_text.startswith(self.bot_token)
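
The streaming path above now memoizes a name-to-index map on first use instead of rebuilding it for every chunk, guarded by a plain `hasattr` check. Roughly this pattern (a simplified sketch, not the diff's exact code):

    class Detector:
        def _indices(self, tool_names):
            # Built once on the first streamed chunk, reused for later chunks.
            if not hasattr(self, "_tool_indices"):
                self._tool_indices = {name: i for i, name in enumerate(tool_names)}
            return self._tool_indices

One consequence worth noting: nothing in this hunk invalidates `_tool_indices`, so a detector instance reused with a different tool list would keep serving the stale mapping.
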
@@ -149,8 +178,18 @@ class BaseFormatDetector:
                      _is_complete_json(current_text[start_idx : start_idx + end_idx])
                  )
                  start_idx += end_idx + len("; ")
-                 # depending on the prompt Llama can use
-                 # either arguments or parameters
+
+                 # Validate tool name if present
+                 if "name" in obj and obj["name"] not in self._tool_indices:
+                     # Invalid tool name - reset state
+                     self._buffer = ""
+                     self.current_tool_id = -1
+                     self.current_tool_name_sent = False
+                     if self.streamed_args_for_tool:
+                         self.streamed_args_for_tool.pop()
+                     return StreamingParseResult()
+
+                 # Handle parameters/arguments consistency
                  if "parameters" in obj:
                      assert (
                          "arguments" not in obj
@@ -159,29 +198,17 @@ class BaseFormatDetector:
                  tool_call_arr.append(obj)

          except partial_json_parser.core.exceptions.MalformedJSON:
-             # not enough tokens to parse into JSON yet
              return StreamingParseResult()

-         # select as the current tool call the one we're on the state at
-         current_tool_call: Dict = (
-             tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
-         )
-
-         # case -- if no tokens have been streamed for the tool, e.g.
-         # only the array brackets, stream nothing
          if len(tool_call_arr) == 0:
              return StreamingParseResult()

-         # case: we are starting a new tool in the array
-         # -> array has > 0 length AND length has moved past cursor
-         elif (
-             len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
-         ):
+         current_tool_call: Dict = (
+             tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+         )

-                 # if we're moving on to a new call, first make sure we
-                 # haven't missed anything in the previous one that was
-                 # auto-generated due to JSON completions, but wasn't
-                 # streamed to the client yet.
+         # Handle new tool in array
+         if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
              if self.current_tool_id >= 0:
                  cur_arguments = current_tool_call.get("arguments")
                  if cur_arguments:
@@ -190,7 +217,6 @@ class BaseFormatDetector:
                      argument_diff = cur_args_json[sent:]

                      res = StreamingParseResult(
-                         normal_text=None,
                          calls=[
                              ToolCallItem(
                                  tool_index=self.current_tool_id,
@@ -206,23 +232,20 @@ class BaseFormatDetector:
                      res = StreamingParseResult()
              else:
                  res = StreamingParseResult()
-             # re-set stuff pertaining to progress in the current tool
+
              self.current_tool_id = len(tool_call_arr) - 1
              self.current_tool_name_sent = False
              self.streamed_args_for_tool.append("")
-             print("starting on new tool %d", self.current_tool_id)
              return res

-         # if the current tool name hasn't been sent, send if available
-         # - otherwise send nothing
+         # Handle tool name
          elif not self.current_tool_name_sent:
              function_name = current_tool_call.get("name")
-             if function_name:
+             if function_name and function_name in self._tool_indices:
                  res = StreamingParseResult(
-                     normal_text=None,
                      calls=[
                          ToolCallItem(
-                             tool_index=self.current_tool_id,
+                             tool_index=self._tool_indices[function_name],
                              name=function_name,
                              parameters="",
                          )
@@ -232,8 +255,7 @@ class BaseFormatDetector:
              else:
                  res = StreamingParseResult()

-         # now we know we're on the same tool call and we're streaming
-         # arguments
+         # Handle streaming arguments
          else:
              cur_arguments = current_tool_call.get("arguments")
              res = StreamingParseResult()
@@ -250,13 +272,12 @@ class BaseFormatDetector:
                          argument_diff = cur_args_json[sent:]
                          self._buffer = ""
                          self.prev_tool_call_arr[self.current_tool_id].clear()
-                         self.current_tool_name_sent: bool = False
+                         self.current_tool_name_sent = False
                          self.streamed_args_for_tool[self.current_tool_id] = ""

                      elif prev_arguments:
                          prev_args_json = json.dumps(prev_arguments)
                          if cur_args_json != prev_args_json:
-
                              prefix = _find_common_prefix(prev_args_json, cur_args_json)
                              argument_diff = prefix[sent:]
@@ -279,8 +300,7 @@ class BaseFormatDetector:
              return res

          except Exception as e:
-             print(e)
-             # Skipping chunk as a result of tool streaming extraction error
+             logger.error(f"Error in parse_streaming_increment: {e}")
              return StreamingParseResult()


@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
      Detector for Llama 3.2 models.
      Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}
-     Does not require a closing tag "</python_tag|>",
-     relies on json.loads(...) success to determine if JSON is complete.
      """

      def __init__(self):
-         """
-         Initializes the detector with necessary state variables.
-         """
          super().__init__()
          self.bot_token = "<|python_tag|>"

      def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
-         """
-         One-time parsing: Detects and parses tool calls in the provided text.
-
-         :param text: The complete text to parse.
-         :param tools: List of available tools.
-         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
-         """
-
+         """Parse function calls from text, handling multiple JSON objects."""
          if "<|python_tag|>" not in text:
              return []
-         _, action = text.split("<|python_tag|>")
-         action = json.loads(action)
-         return self.parse_base_json(action, tools)
+
+         _, action_text = text.split("<|python_tag|>")
+
+         # Split by semicolon and process each part
+         json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+
+         all_actions = []
+         for part in json_parts:
+             try:
+                 # Parse each individual JSON object
+                 action = json.loads(part)
+                 all_actions.append(action)
+             except json.JSONDecodeError as e:
+                 logger.warning(f"Failed to parse JSON part: {part}")
+                 logger.warning(f"JSON parse error: {str(e)}")
+                 continue
+
+         # Only process if we found valid JSON objects
+         if all_actions:
+             return self.parse_base_json(all_actions, tools)
+
+         return []


  class MultiFormatParser:
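
`Llama32Detector.detect_and_parse` previously fed everything after `<|python_tag|>` to a single `json.loads`, so output like `{...}; {...}` failed outright. The new version splits on `;` and parses each fragment, skipping the ones that do not decode. A self-contained illustration of that flow (sample text only):

    import json

    text = '<|python_tag|>{"name": "a", "arguments": {}}; {"name": "b", "arguments": {"x": 1}}'
    _, action_text = text.split("<|python_tag|>")
    parts = [p.strip() for p in action_text.split(";") if p.strip()]

    actions = []
    for part in parts:
        try:
            actions.append(json.loads(part))
        except json.JSONDecodeError:
            continue  # the real code logs a warning here and moves on
    print(actions)
    # [{'name': 'a', 'arguments': {}}, {'name': 'b', 'arguments': {'x': 1}}]

The split is purely textual, so a `;` occurring inside a JSON string argument would cut that fragment in two; both halves would then fail `json.loads` and be dropped rather than crash the parse.
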
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
      def __init__(self, model_runner: ModelRunner):
          # Lazy import to avoid the initialization of cuda context
          from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+             extend_attention_fwd,
              flash_decode_attention_fwd,
              flash_decode_sparse_attention_fwd,
          )
-         from sglang.srt.layers.attention.triton_ops.extend_attention import (
-             extend_attention_fwd,
-         )

          super().__init__()

@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
              (max_bs + 1,), dtype=torch.int32, device=model_runner.device
          )
          self.req_to_token = model_runner.req_to_token_pool.req_to_token
+         self.qo_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+         )

          self.num_head = (
              model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
      def init_forward_metadata(self, forward_batch: ForwardBatch):
          """Init auxiliary variables for triton attention backend."""

+         bs = forward_batch.batch_size
+         kv_indptr = self.kv_indptr
+
          if forward_batch.forward_mode.is_decode():
              attn_logits = torch.empty(
                  (
@@ -68,31 +74,62 @@ class TritonAttnBackend(AttentionBackend):

              max_extend_len = None

-             kv_indptr = self.kv_indptr
-             bs = len(forward_batch.req_pool_indices)
              kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
              kv_indptr = kv_indptr[: bs + 1]
              kv_indices = torch.empty(
-                 forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                 forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
              )
              create_flashinfer_kv_indices_triton[(bs,)](
-                 forward_batch.req_to_token_pool.req_to_token,
+                 self.req_to_token,
                  forward_batch.req_pool_indices,
                  forward_batch.seq_lens,
                  kv_indptr,
                  None,
                  kv_indices,
-                 forward_batch.req_to_token_pool.req_to_token.stride(0),
+                 self.req_to_token.stride(0),
              )

+             qo_indptr = None
+             custom_mask = None
+             mask_offsets = None
          else:
+             kv_indptr[1 : bs + 1] = torch.cumsum(
+                 forward_batch.extend_prefix_lens, dim=0
+             )
+             kv_indptr = kv_indptr[: bs + 1]
+             kv_indices = torch.empty(
+                 forward_batch.extend_prefix_lens.sum().item(),
+                 dtype=torch.int32,
+                 device=self.device,
+             )
+             create_flashinfer_kv_indices_triton[(bs,)](
+                 self.req_to_token,
+                 forward_batch.req_pool_indices,
+                 forward_batch.extend_prefix_lens,
+                 kv_indptr,
+                 None,
+                 kv_indices,
+                 self.req_to_token.stride(0),
+             )
+
+             qo_indptr = self.qo_indptr
+             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+             qo_indptr = qo_indptr[: bs + 1]
+             custom_mask = None
+             mask_offsets = None
+
              attn_logits = None
              max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-             kv_indptr = None
-             kv_indices = None
-
-         self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+         self.forward_metadata = (
+             attn_logits,
+             max_extend_len,
+             kv_indptr,
+             kv_indices,
+             qo_indptr,
+             custom_mask,
+             mask_offsets,
+         )

      def init_cuda_graph_state(self, max_bs: int):
          self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
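
The extend (prefill) branch added above builds FlashInfer-style ragged indices for the prefill path too: `kv_indptr` is the cumulative sum of per-request prefix lengths and `qo_indptr` of the new extend tokens, so request `i` owns slice `indptr[i]:indptr[i+1]` of the flattened buffers. A standalone sketch of the indptr arithmetic (example lengths are made up):

    import torch

    extend_prefix_lens = torch.tensor([3, 0, 5], dtype=torch.int32)  # cached tokens per request
    extend_seq_lens = torch.tensor([4, 2, 1], dtype=torch.int32)     # new tokens per request
    bs = extend_prefix_lens.numel()

    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(extend_prefix_lens, dim=0)  # tensor([0, 3, 3, 8])

    qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)     # tensor([0, 4, 6, 7])

    # Request i's cached KV lives at kv_indices[kv_indptr[i] : kv_indptr[i + 1]];
    # its query rows occupy qo_indptr[i] : qo_indptr[i + 1].

`custom_mask` and `mask_offsets` are always `None` in this version; only the tuple slots are added.
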
@@ -144,6 +181,9 @@ class TritonAttnBackend(AttentionBackend):
              None,
              kv_indptr,
              kv_indices,
+             None,
+             None,
+             None,
          )

      def init_forward_metadata_replay_cuda_graph(
  def init_forward_metadata_replay_cuda_graph(
@@ -197,7 +237,15 @@ class TritonAttnBackend(AttentionBackend):
197
237
  layer, forward_batch.out_cache_loc, k, v
198
238
  )
199
239
 
200
- _, max_extend_len, _, _ = self.forward_metadata
240
+ (
241
+ _,
242
+ max_extend_len,
243
+ kv_indptr,
244
+ kv_indices,
245
+ qo_indptr,
246
+ custom_mask,
247
+ mask_offsets,
248
+ ) = self.forward_metadata
201
249
  self.extend_attention_fwd(
202
250
  q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
203
251
  k.contiguous(),
@@ -205,11 +253,11 @@ class TritonAttnBackend(AttentionBackend):
              o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
              forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
              forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-             forward_batch.req_to_token_pool.req_to_token,
-             forward_batch.req_pool_indices,
-             forward_batch.seq_lens,
-             forward_batch.extend_seq_lens,
-             forward_batch.extend_start_loc,
+             qo_indptr,
+             kv_indptr,
+             kv_indices,
+             custom_mask,
+             mask_offsets,
              max_extend_len,
              layer.scaling,
              layer.logit_cap,
@@ -235,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
          else:
              o = torch.empty_like(q)

-         attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+         attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

          if save_kv_cache:
              forward_batch.token_to_kv_pool.set_kv_buffer(
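
With `forward_metadata` widened from four to seven fields, every producer and consumer (CUDA graph capture, extend forward, decode forward) has to spell out all seven positions, as the `attn_logits, _, kv_indptr, kv_indices, _, _, _` line above shows. A `NamedTuple` would make that kind of widening mechanical; a hypothetical shape, not part of the diff:

    from typing import NamedTuple, Optional
    import torch

    class ForwardMetadata(NamedTuple):
        attn_logits: Optional[torch.Tensor]
        max_extend_len: Optional[int]
        kv_indptr: Optional[torch.Tensor]
        kv_indices: Optional[torch.Tensor]
        qo_indptr: Optional[torch.Tensor]
        custom_mask: Optional[torch.Tensor]
        mask_offsets: Optional[torch.Tensor]

    # Consumers could then read meta.max_extend_len by name, and adding an
    # eighth field would not disturb call sites that access fields by name.
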