sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/infer_batch.py (new file)
@@ -0,0 +1,619 @@
+"""Meta data for requests and batches"""
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import List
+
+import numpy as np
+import torch
+
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+
+
+class ForwardMode(IntEnum):
+    PREFILL = auto()
+    EXTEND = auto()
+    DECODE = auto()
+
+
+class BaseFinishReason:
+    def __init__(self, is_error: bool = False):
+        self.is_error = is_error
+
+    def __str__(self):
+        raise NotImplementedError("Subclasses must implement this method")
+
+
+class FINISH_MATCHED_TOKEN(BaseFinishReason):
+    def __init__(self, matched: int | List[int]):
+        super().__init__()
+        self.matched = matched
+
+    def __str__(self) -> str:
+        return f"FINISH_MATCHED_TOKEN: {self.matched}"
+
+
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
+        super().__init__()
+        self.length = length
+
+    def __str__(self) -> str:
+        return f"FINISH_LENGTH: {self.length}"
+
+
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
+        super().__init__()
+        self.matched = matched
+
+    def __str__(self) -> str:
+        return f"FINISH_MATCHED_STR: {self.matched}"
+
+
+class FINISH_ABORT(BaseFinishReason):
+    def __init__(self):
+        super().__init__(is_error=True)
+
+    def __str__(self) -> str:
+        return "FINISH_ABORT"
+
+
+class Req:
+    def __init__(self, rid, origin_input_text, origin_input_ids):
+        self.rid = rid
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids = origin_input_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.prev_output_str = ""
+        self.prev_output_ids = []
+        self.output_ids = []
+        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0
+
+        # For vision input
+        self.pixel_values = None
+        self.image_size = None
+        self.image_offset = 0
+        self.pad_value = None
+
+        # Sampling parameters
+        self.sampling_params = None
+        self.stream = False
+
+        self.tokenizer = None
+
+        # Check finish
+        self.finished_reason = None
+
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
+        self.prefill_token_logprobs = None
+        self.prefill_top_logprobs = None
+        self.decode_token_logprobs = []
+        self.decode_top_logprobs = []
+        # The tokens is prefilled but need to be considered as decode tokens
+        # and should be updated for the decode logprobs
+        self.last_update_decode_tokens = 0
+
+        # Constrained decoding
+        self.regex_fsm = None
+        self.regex_fsm_state = 0
+        self.jump_forward_map = None
+
+    # whether request reached finished condition
+    def finished(self) -> bool:
+        return self.finished_reason is not None
+
+    def partial_decode(self, ids):
+        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
+        first_token = (
+            first_token.decode() if isinstance(first_token, bytes) else first_token
+        )
+        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+
+    def max_new_tokens(self):
+        return self.sampling_params.max_new_tokens
+
+    def check_finished(self):
+        if self.finished():
+            return
+
+        if (
+            len(self.prev_output_ids) + len(self.output_ids)
+            >= self.sampling_params.max_new_tokens
+        ):
+            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+            return
+
+        if (
+            self.output_ids[-1] == self.tokenizer.eos_token_id
+            and not self.sampling_params.ignore_eos
+        ):
+            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            return
+
+        if len(self.sampling_params.stop_strs) > 0:
+            tail_str = self.tokenizer.decode(
+                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
+            )
+
+            for stop_str in self.sampling_params.stop_strs:
+                # FIXME: (minor) try incremental match in prev_output_str
+                if stop_str in tail_str or stop_str in self.prev_output_str:
+                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
+                    return
+
+    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        # FIXME: This logic does not really solve the problem of determining whether
+        # there should be a leading space.
+        cur_output_str = self.partial_decode(self.output_ids)
+
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = (
+            self.origin_input_text
+            + self.prev_output_str
+            + cur_output_str
+            + jump_forward_str
+        )
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+        self.origin_input_ids = all_ids[:prompt_tokens]
+        self.origin_input_ids_unpadded = self.origin_input_ids
+        # NOTE: the output ids may not strictly correspond to the output text
+        old_prev_output_ids = self.prev_output_ids
+        self.prev_output_ids = all_ids[prompt_tokens:]
+        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
+        self.output_ids = []
+
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_prev_output_ids):
+                if old_id == self.prev_output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.prev_output_ids) - k
+
+        # print("=" * 100)
+        # print(f"Catch jump forward:\n{jump_forward_str}")
+        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
+        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+
+        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
+        # print("*" * 100)
+
+    def __repr__(self):
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+
+
+@dataclass
+class Batch:
+    reqs: List[Req]
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+    tree_cache: RadixCache
+
+    # batched arguments to model runner
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
+    prefix_lens: torch.Tensor = None
+    position_ids_offsets: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None
+    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_end: torch.Tensor = None
+
+    # for processing logprobs
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # for multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
+    # other arguments for control
+    output_ids: torch.Tensor = None
+    extend_num_tokens: int = None
+
+    # batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    frequency_penalties: torch.Tensor = None
+    presence_penalties: torch.Tensor = None
+    logit_bias: torch.Tensor = None
+
+    @classmethod
+    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+        return_logprob = any(req.return_logprob for req in reqs)
+
+        return cls(
+            reqs=reqs,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool=token_to_kv_pool,
+            tree_cache=tree_cache,
+            return_logprob=return_logprob,
+        )
+
+    def is_empty(self):
+        return len(self.reqs) == 0
+
+    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+        device = "cuda"
+        bs = len(self.reqs)
+        reqs = self.reqs
+        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
+        prefix_indices = [r.prefix_indices for r in reqs]
+
+        # Handle prefix
+        flatten_input_ids = []
+        extend_lens = []
+        prefix_lens = []
+        seq_lens = []
+
+        req_pool_indices = self.req_to_token_pool.alloc(bs)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        for i in range(bs):
+            flatten_input_ids.extend(input_ids[i])
+            extend_lens.append(len(input_ids[i]))
+
+            if len(prefix_indices[i]) == 0:
+                prefix_lens.append(0)
+            else:
+                prefix_lens.append(len(prefix_indices[i]))
+                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+                    : len(prefix_indices[i])
+                ] = prefix_indices[i]
+
+            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
+
+        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
+
+        # Alloc mem
+        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
+        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
+        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+        if out_cache_loc is None:
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+
+            if out_cache_loc is None:
+                print("Prefill out of memory. This should never happen.")
+                self.tree_cache.pretty_print()
+                exit()
+
+        pt = 0
+        for i in range(bs):
+            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
+                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
+            ] = out_cache_loc[pt : pt + extend_lens[i]]
+            pt += extend_lens[i]
+
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
+        for i in range(bs):
+            if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
+                logit_bias[i] = int_token_logit_bias
+
+        # Set fields
+        self.input_ids = torch.tensor(
+            flatten_input_ids, dtype=torch.int32, device=device
+        )
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+        ]
+        self.req_pool_indices = req_pool_indices
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
+        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        self.position_ids_offsets = position_ids_offsets
+        self.extend_num_tokens = extend_num_tokens
+        self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+
+        self.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        self.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        ).view(-1, 1)
+        self.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        ).view(-1, 1)
+        self.frequency_penalties = torch.tensor(
+            [r.sampling_params.frequency_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+        self.presence_penalties = torch.tensor(
+            [r.sampling_params.presence_penalty for r in reqs],
+            dtype=torch.float,
+            device=device,
+        )
+        self.logit_bias = logit_bias
+
+    def check_decode_mem(self):
+        bs = len(self.reqs)
+        if self.token_to_kv_pool.available_size() >= bs:
+            return True
+
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+
+        if self.token_to_kv_pool.available_size() >= bs:
+            return True
+
+        return False
+
+    def retract_decode(self):
+        sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
+        sorted_indices.sort(
+            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            reverse=True,
+        )
+
+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        while self.token_to_kv_pool.available_size() < len(self.reqs):
+            idx = sorted_indices.pop()
+            req = self.reqs[idx]
+            retracted_reqs.append(req)
+
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            # release the last node
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+            cur_output_str = req.partial_decode(req.output_ids)
+            req.prev_output_str = req.prev_output_str + cur_output_str
+            req.prev_output_ids.extend(req.output_ids)
+
+            req.prefix_indices = None
+            req.last_node = None
+            req.extend_input_len = 0
+            req.output_ids = []
+
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
+
+        self.filter_batch(sorted_indices)
+
+        return retracted_reqs
+
+    def check_for_jump_forward(self, model_runner):
+        jump_forward_reqs = []
+        filter_indices = [i for i in range(len(self.reqs))]
+
+        req_pool_indices_cpu = None
+
+        for i, req in enumerate(self.reqs):
+            if req.jump_forward_map is not None:
+                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
+                if res is not None:
+                    jump_forward_str, next_state = res
+                    if len(jump_forward_str) <= 1:
+                        continue
+
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+
+                    # insert the old request into tree_cache
+                    self.tree_cache.cache_req(
+                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        last_uncached_pos=len(req.prefix_indices),
+                        req_pool_idx=req_pool_indices_cpu[i],
+                    )
+
+                    # unlock the last node
+                    self.tree_cache.dec_lock_ref(req.last_node)
+
+                    # jump-forward
+                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
+
+                    # re-applying image padding
+                    if req.pixel_values is not None:
+                        (
+                            req.origin_input_ids,
+                            req.image_offset,
+                        ) = model_runner.model.pad_input_ids(
+                            req.origin_input_ids_unpadded,
+                            req.pad_value,
+                            req.pixel_values.shape,
+                            req.image_size,
+                        )
+
+                    jump_forward_reqs.append(req)
+                    filter_indices.remove(i)
+
+        if len(filter_indices) < len(self.reqs):
+            self.filter_batch(filter_indices)
+
+        return jump_forward_reqs
+
+    def prepare_for_decode(self, input_ids=None):
+        if input_ids is None:
+            input_ids = [
+                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+            ]
+        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+        self.seq_lens.add_(1)
+        self.prefix_lens = None
+
+        # Alloc mem
+        bs = len(self.reqs)
+        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
+        if alloc_res is None:
+            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+
+            if self.out_cache_loc is None:
+                print("Decode out of memory. This should never happen.")
+                self.tree_cache.pretty_print()
+                exit()
+
+            self.out_cache_cont_start = None
+            self.out_cache_cont_end = None
+        else:
+            self.out_cache_loc = alloc_res[0]
+            self.out_cache_cont_start = alloc_res[1]
+            self.out_cache_cont_end = alloc_res[2]
+
+        self.req_to_token_pool.req_to_token[
+            self.req_pool_indices, self.seq_lens - 1
+        ] = self.out_cache_loc
+
+    def filter_batch(self, unfinished_indices: List[int]):
+        self.reqs = [self.reqs[i] for i in unfinished_indices]
+        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
+        self.seq_lens = self.seq_lens[new_indices]
+        self.input_ids = None
+        self.req_pool_indices = self.req_pool_indices[new_indices]
+        self.prefix_lens = None
+        self.position_ids_offsets = self.position_ids_offsets[new_indices]
+        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
+        self.return_logprob = any(req.return_logprob for req in self.reqs)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "frequency_penalties",
+            "presence_penalties",
+            "logit_bias",
+        ]:
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
+
+    def merge(self, other: "Batch"):
+        self.reqs.extend(other.reqs)
+
+        self.req_pool_indices = torch.concat(
+            [self.req_pool_indices, other.req_pool_indices]
+        )
+        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
+        self.prefix_lens = None
+        self.position_ids_offsets = torch.concat(
+            [self.position_ids_offsets, other.position_ids_offsets]
+        )
+        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        self.return_logprob = any(req.return_logprob for req in self.reqs)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "frequency_penalties",
+            "presence_penalties",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+
+    def sample(self, logits: torch.Tensor):
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(self.temperatures)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
+        has_regex = any(req.regex_fsm is not None for req in self.reqs)
+        if has_regex:
+            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
+            for i, req in enumerate(self.reqs):
+                if req.regex_fsm is not None:
+                    allowed_mask.zero_()
+                    allowed_mask[
+                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
+                    ] = 1
+                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
+
+        # TODO(lmzheng): apply penalty
+        probs = torch.softmax(logits, dim=-1)
+        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
+            -1
+        )
+        batch_next_token_probs = torch.gather(
+            probs_sort, dim=1, index=sampled_index
+        ).view(-1)
+
+        if has_regex:
+            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
+            for i, req in enumerate(self.reqs):
+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.next_state(
+                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
+                    )
+
+        return batch_next_token_ids, batch_next_token_probs
+
+
+def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    return probs_sort, probs_idx
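
For orientation, the decode loop in this new file is: Batch.init_new builds a batch, prepare_for_extend / prepare_for_decode stage the batched tensors, sample picks the next tokens, and filter_batch drops finished requests. The snippet below is a minimal standalone sketch of just the top-p/top-k filtering that Batch.sample delegates to _top_p_top_k; it assumes only PyTorch, and the toy probabilities, thresholds, and batch size are illustrative values, not taken from the package.

import torch

# _top_p_top_k copied from the diff above (unchanged).
def _top_p_top_k(probs, top_ps, top_ks):
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
    ] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    return probs_sort, probs_idx

# Toy batch: 2 requests over a 5-token vocabulary (CPU is fine for the demo).
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03],
                      [0.96, 0.01, 0.01, 0.01, 0.01]])
top_ps = torch.tensor([[0.80], [0.90]])  # per-request nucleus thresholds
top_ks = torch.tensor([[3], [5]])        # per-request top-k limits

# Filter, then sample one token per request, mirroring Batch.sample.
weights, idx = _top_p_top_k(probs.clone(), top_ps, top_ks)
sampled = torch.gather(idx, dim=1, index=torch.multinomial(weights, num_samples=1))
print(sampled.view(-1))  # one token id per request, drawn only from surviving tokens

Per request, a sorted token gets zero weight when the cumulative probability of the higher-ranked tokens already exceeds top_p, or when its rank is at or beyond top_k; torch.multinomial then samples only from the survivors, and the gather maps the sampled rank back to the original token id.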