sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -61,7 +61,7 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
+        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
 
         self.model_path = server_args.model_path
         self.hf_config = get_config(
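The tokenizer-to-controller link above is a plain ZeroMQ PUSH/PULL pair; the change only renames the port (router_port → controller_port) to match the controller-based process layout. A minimal sketch of the pattern, assuming pyzmq and a hypothetical port 30001:

    import zmq

    context = zmq.Context.instance()

    # Controller side: a PULL socket bound to a fixed local port.
    recv_from_tokenizer = context.socket(zmq.PULL)
    recv_from_tokenizer.bind("tcp://127.0.0.1:30001")  # hypothetical controller_port

    # Tokenizer side: a PUSH socket connected to the same address.
    send_to_controller = context.socket(zmq.PUSH)
    send_to_controller.connect("tcp://127.0.0.1:30001")

    # send_pyobj/recv_pyobj pickle arbitrary Python objects over the socket.
    send_to_controller.send_pyobj({"rid": "example-request"})
    print(recv_from_tokenizer.recv_pyobj())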
@@ -69,7 +69,10 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-        self.context_len = get_context_length(self.hf_config)
+        if server_args.context_length is not None:
+            self.context_len = server_args.context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)
 
         if is_multimodal_model(self.model_path):
             self.processor = get_processor(
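This change lets an explicit server argument override the context window derived from the Hugging Face config. A minimal sketch of the fallback logic, assuming max_position_embeddings is the config attribute consulted (the real get_context_length probes several field names):

    def resolve_context_len(cli_context_length, hf_config):
        # Prefer the explicit server argument when it is given.
        if cli_context_length is not None:
            return cli_context_length
        # Hypothetical fallback: the real helper checks several config fields.
        return getattr(hf_config, "max_position_embeddings", 2048)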
@@ -119,125 +122,150 @@ class TokenizerManager:
 
         obj.post_init()
         is_single = obj.is_single
-        if is_single:
-            rid = obj.rid
-
-            if obj.input_ids is None:
-                input_ids = self.tokenizer.encode(obj.text)
-            else:
-                input_ids = obj.input_ids
 
-            if len(input_ids) >= self.context_len:
-                raise ValueError(
-                    f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)."
-                )
+        if is_single:
+            async for response in self._handle_single_request(obj, request):
+                yield response
+        else:
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")
 
-            sampling_params = SamplingParams(**obj.sampling_params)
-            if sampling_params.max_new_tokens != 0:
-                sampling_params.normalize(self.tokenizer)
-                sampling_params.verify()
+            async for response in self._handle_batch_request(obj, request):
+                yield response
 
-            if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
-                    obj.image_data[0]
-                )
-            elif isinstance(obj.image_data, str):
-                pixel_values, image_hash, image_size = await self.get_pixel_values(
-                    obj.image_data
-                )
+    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
+        if is_prefill:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
             else:
-                pixel_values, image_hash, image_size = None, None, None
-            tokenized_obj = TokenizedGenerateReqInput(
-                rid=rid,
-                input_text=obj.text,
-                input_ids=input_ids,
-                pixel_values=pixel_values,
-                image_hash=image_hash,
-                image_size=image_size,
-                sampling_params=sampling_params,
-                return_logprob=obj.return_logprob,
-                logprob_start_len=obj.logprob_start_len,
-                top_logprobs_num=obj.top_logprobs_num,
-                stream=obj.stream,
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]
+        else:
+            rid = obj.rid if index is None else obj.rid[index]
+            input_text = obj.text if index is None else obj.text[index]
+            input_ids = (
+                self.tokenizer.encode(input_text)
+                if obj.input_ids is None
+                else obj.input_ids
             )
-            self.send_to_router.send_pyobj(tokenized_obj)
+            if index is not None and obj.input_ids:
+                input_ids = obj.input_ids[index]
 
-            event = asyncio.Event()
-            state = ReqState([], False, event)
-            self.rid_to_state[rid] = state
+            self._validate_input_length(input_ids)
+            sampling_params = self._get_sampling_params(
+                obj.sampling_params if index is None else obj.sampling_params[index]
+            )
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data if index is None else obj.image_data[index]
+            )
+            return_logprob = (
+                obj.return_logprob if index is None else obj.return_logprob[index]
+            )
+            logprob_start_len = (
+                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+            )
+            top_logprobs_num = (
+                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+            )
 
-            while True:
-                try:
-                    await asyncio.wait_for(event.wait(), timeout=4)
-                except asyncio.TimeoutError:
-                    if request is not None and await request.is_disconnected():
-                        self.abort_request(rid)
-                        raise ValueError(f"Abort request {rid}")
+        tokenized_obj = TokenizedGenerateReqInput(
+            rid,
+            input_text,
+            input_ids,
+            pixel_values,
+            image_hash,
+            image_size,
+            sampling_params,
+            return_logprob,
+            logprob_start_len,
+            top_logprobs_num,
+            obj.stream,
+        )
+        self.send_to_router.send_pyobj(tokenized_obj)
+
+        event = asyncio.Event()
+        state = ReqState([], False, event)
+        self.rid_to_state[rid] = state
+        if is_prefill:
+            await self._wait_for_prefill_response(event, state, obj, request, rid)
+            yield input_ids
+        else:
+            async for response in self._wait_for_response(
+                event, state, obj, rid, request
+            ):
+                yield response
+
+    async def _handle_batch_request(self, obj, request):
+        batch_size = obj.batch_size
+        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+
+        if parallel_sample_num != 1:
+            # Send the prefill requests first.
+            parallel_sample_num += 1
+            input_id_result = [] if obj.input_ids is None else None
+            for i in range(batch_size):
+                async for input_id in self._handle_single_request(
+                    obj, request, index=i, is_prefill=True
+                ):
+                    if input_id_result is not None:
+                        input_id_result.append(input_id)
+            if input_id_result is not None and len(input_id_result) > 1:
+                obj.input_ids = input_id_result
+            elif input_id_result is not None:
+                obj.input_ids = input_id_result[0]
+        # First send out all requests
+        for i in range(batch_size):
+            for j in range(parallel_sample_num):
+                if j == 0 and parallel_sample_num != 1:
                     continue
-
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
+                index = i * parallel_sample_num + j
+                if parallel_sample_num != 1:
+                    # With parallel sampling, the prefill requests occupy the
+                    # first slots, so the index is:
+                    # j + i * (parallel_sample_num - 1) + batch_size - 1
+                    index += batch_size - 1 - i
+                rid = obj.rid[index]
+                if parallel_sample_num == 1:
+                    # select operation: one request per batch item
+                    if obj.input_ids is None:
+                        input_text = obj.text[i]
+                        input_ids = self.tokenizer.encode(obj.text[i])
+                    else:
+                        input_text = None
+                        input_ids = obj.input_ids[i]
+                else:
+                    if batch_size == 1:
+                        input_text = obj.text
+                        input_ids = obj.input_ids
+                    else:
+                        input_text = obj.text[i]
+                        input_ids = obj.input_ids[i]
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
                 )
 
-                if self.server_args.log_requests and state.finished:
-                    logger.info(f"in={obj.text}, out={out}")
-
-                state.out_list = []
-                if state.finished:
-                    del self.rid_to_state[rid]
-
-                    yield out
-
-                    break
-
-                event.clear()
-
-                yield out
-        else:
-            if obj.stream:
-                raise ValueError("Do not support stream for batch mode.")
-
-            if obj.input_ids is None:
-                bs = len(obj.text)
-            else:
-                bs = len(obj.input_ids)
-
-            for i in range(bs):
-                rid = obj.rid[i]
-
-                if obj.input_ids is None:
-                    input_text = obj.text[i]
-                    input_ids = self.tokenizer.encode(obj.text[i])
-                else:
-                    input_text = None
-                    input_ids = obj.input_ids[i]
-
-                sampling_params = SamplingParams(**obj.sampling_params[i])
-                if sampling_params.max_new_tokens != 0:
-                    sampling_params.normalize(self.tokenizer)
-                    sampling_params.verify()
-                if obj.image_data[i] is None:
-                    pixel_values, image_hash, image_size = None, None, None
-                else:
-                    pixel_values, image_hash, image_size = await self.get_pixel_values(
-                        obj.image_data[i]
-                    )
                 tokenized_obj = TokenizedGenerateReqInput(
-                    rid=rid,
-                    input_text=input_text,
-                    input_ids=input_ids,
-                    pixel_values=pixel_values,
-                    image_hash=image_hash,
-                    image_size=image_size,
-                    sampling_params=sampling_params,
-                    return_logprob=obj.return_logprob[i],
-                    logprob_start_len=obj.logprob_start_len[i],
-                    top_logprobs_num=obj.top_logprobs_num[i],
-                    stream=obj.stream,
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)
 
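The index arithmetic in _handle_batch_request is the subtle part: with parallel sampling (n > 1), one extra prefill request per batch item is sent first, and those prefill rids occupy the first slots of obj.rid before the sampled completions. A small worked example, assuming batch_size = 2 and n = 2 (so parallel_sample_num is bumped to 3):

    batch_size, parallel_sample_num = 2, 3  # n=2 plus the extra prefill slot
    for i in range(batch_size):
        for j in range(parallel_sample_num):
            if j == 0:
                continue  # slot j=0 is already covered by the prefill pass
            index = i * parallel_sample_num + j + (batch_size - 1 - i)
            print(f"batch item {i}, sample {j} -> rid index {index}")
    # batch item 0 -> rid indices 2, 3; batch item 1 -> rid indices 4, 5
    # (rid indices 0 and 1 belong to the two prefill requests)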
@@ -245,9 +273,16 @@ class TokenizerManager:
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state
 
-        output_list = []
-        for i in range(bs):
-            rid = obj.rid[i]
+        # Then wait for all responses
+        output_list = []
+        for i in range(batch_size):
+            for j in range(parallel_sample_num):
+                if j == 0 and parallel_sample_num != 1:
+                    continue
+                index = i * parallel_sample_num + j
+                if parallel_sample_num != 1:
+                    index += batch_size - 1 - i
+                rid = obj.rid[index]
                 state = self.rid_to_state[rid]
 
                 while True:
@@ -260,19 +295,86 @@ class TokenizerManager:
                             self.abort_request(rid)
                             raise ValueError(f"Abort request {rid}")
                         continue
-
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],
-                        obj.return_logprob[i],
-                        obj.top_logprobs_num[i],
+                        obj.return_logprob[index],
+                        obj.top_logprobs_num[index],
                         obj.return_text_in_logprobs,
                     )
                 )
                 assert state.finished
                 del self.rid_to_state[rid]
 
-        yield output_list
+        yield output_list
+
+    def _validate_input_length(self, input_ids):
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+        sampling_params = SamplingParams(**sampling_params_data)
+        if max_new_tokens is not None:
+            sampling_params.max_new_tokens = max_new_tokens
+        if sampling_params.max_new_tokens != 0:
+            sampling_params.normalize(self.tokenizer)
+            sampling_params.verify()
+        return sampling_params
+
+    async def _get_pixel_values(self, image_data):
+        if isinstance(image_data, list) and len(image_data) > 0:
+            return await self.get_pixel_values(image_data[0])
+        elif isinstance(image_data, str):
+            return await self.get_pixel_values(image_data)
+        else:
+            return None, None, None
+
+    async def _wait_for_response(self, event, state, obj, rid, request):
+        while True:
+            try:
+                await asyncio.wait_for(event.wait(), timeout=4)
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(rid)
+                    raise ValueError(f"Abort request {rid}")
+                continue
+
+            out = self.convert_logprob_style(
+                state.out_list[-1],
+                obj.return_logprob,
+                obj.top_logprobs_num,
+                obj.return_text_in_logprobs,
+            )
+
+            if self.server_args.log_requests and state.finished:
+                logger.info(f"in={obj.text}, out={out}")
+
+            state.out_list = []
+            if state.finished:
+                del self.rid_to_state[rid]
+                yield out
+                break
+
+            event.clear()
+            yield out
+
+    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+        while True:
+            try:
+                await asyncio.wait_for(state.event.wait(), timeout=4)
+                break
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    for rid in obj.rid:
+                        self.abort_request(rid)
+                    raise ValueError(f"Abort request {rid}")
+                continue
+
+        assert state.finished
+        del self.rid_to_state[rid]
 
     def flush_cache(self):
         req = FlushCacheReq()
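Both _wait_for_response and _wait_for_prefill_response use the same idiom: block on an asyncio.Event with a 4-second timeout, and on every timeout poll whether the HTTP client disconnected so the request can be aborted. A self-contained sketch of the pattern (the names are illustrative, not the sglang API):

    import asyncio

    async def wait_with_disconnect_check(event, is_disconnected, timeout=4):
        # Wake as soon as the event is set, but re-check the client
        # connection every `timeout` seconds while waiting.
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=timeout)
                return True  # a result is ready
            except asyncio.TimeoutError:
                if await is_disconnected():
                    return False  # caller should abort the request

    async def never_disconnected():
        return False

    async def demo():
        event = asyncio.Event()
        asyncio.get_running_loop().call_later(0.1, event.set)
        print(await wait_with_disconnect_check(event, never_disconnected))

    asyncio.run(demo())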
@@ -335,15 +437,16 @@ class TokenizerManager:
             )
 
         if top_logprobs_num > 0:
-            ret["meta_info"][
-                "prefill_top_logprobs"
-            ] = self.detokenize_top_logprobs_tokens(
-                ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+            ret["meta_info"]["prefill_top_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"],
+                    return_text_in_logprobs,
+                )
             )
-            ret["meta_info"][
-                "decode_top_logprobs"
-            ] = self.detokenize_top_logprobs_tokens(
-                ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+            ret["meta_info"]["decode_top_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
             )
         return ret
 
sglang/srt/memory_pool.py CHANGED
@@ -11,6 +11,7 @@ class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
     def __init__(self, size: int, max_context_len: int):
+        self.size = size
         self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
         self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device="cuda"
@@ -21,7 +22,9 @@ class ReqToTokenPool:
         if need_size > self.can_use_mem_size:
             return None
 
-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
+        )
         self.mem_state[select_index] = False
         self.can_use_mem_size -= need_size
 
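Both pools allocate with the same free-list idiom: a boolean mem_state marks free slots, torch.nonzero lists their indices, and the first need_size of them are handed out. A runnable CPU sketch with illustrative sizes:

    import torch

    mem_state = torch.ones(8, dtype=torch.bool)  # True = slot is free
    need_size = 3
    select_index = torch.nonzero(mem_state).squeeze(1)[:need_size].to(torch.int32)
    mem_state[select_index.long()] = False       # mark the slots as used
    print(select_index.tolist())                 # [0, 1, 2]
    print(int(mem_state.sum()))                  # 5 free slots remain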
@@ -42,15 +45,26 @@ class ReqToTokenPool:
 class TokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
-    def __init__(self, size, dtype, head_num, head_dim, layer_num):
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
         self.size = size
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, key/value, head_num, head_dim] for each layer
-        self.kv_data = [
-            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
 
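Splitting kv_data into separate per-layer k_buffer and v_buffer lists keeps each key cache and value cache contiguous, so attention kernels can consume the buffers directly instead of going through strided kv_data[layer_id][:, 0] / [:, 1] views. A small sketch contrasting the two layouts, with illustrative sizes:

    import torch

    size, head_num, head_dim, layer_num = 16, 2, 4, 3

    # Old layout: one (size+1, 2, head_num, head_dim) tensor per layer,
    # with keys and values interleaved along dim 1.
    kv_data = [torch.empty(size + 1, 2, head_num, head_dim) for _ in range(layer_num)]
    k_view = kv_data[0][:, 0]  # a strided, non-contiguous view

    # New layout: separate contiguous K and V tensors per layer.
    k_buffer = [torch.empty(size + 1, head_num, head_dim) for _ in range(layer_num)]
    v_buffer = [torch.empty(size + 1, head_num, head_dim) for _ in range(layer_num)]

    print(k_view.is_contiguous(), k_buffer[0].is_contiguous())  # False True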
@@ -61,16 +75,19 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()
 
-    def get_key_buffer(self, layer_id):
-        return self.kv_data[layer_id][:, 0]
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
 
-    def get_value_buffer(self, layer_id):
-        return self.kv_data[layer_id][:, 1]
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
 
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
 
-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
         buffer_len = len(self.prefetch_buffer)
         if need_size <= buffer_len:
             select_index = self.prefetch_buffer[:need_size]
@@ -79,7 +96,9 @@ class TokenToKVPool:
 
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
+        select_index = (
+            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
+        )
 
         if select_index.shape[0] < addition_size:
             return None
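alloc first serves requests out of prefetch_buffer; only when the buffer runs dry does it scan mem_state again, grabbing at least prefetch_chunk_size slots in one torch.nonzero pass and banking the surplus for later calls. A minimal CPU sketch of the idea (simplified: the out-of-memory check the real code performs is omitted):

    import torch

    prefetch_chunk_size = 4
    mem_state = torch.ones(16, dtype=torch.bool)
    prefetch_buffer = torch.empty(0, dtype=torch.int32)

    def alloc(need_size):
        global prefetch_buffer
        if need_size <= len(prefetch_buffer):  # served entirely from the buffer
            out, prefetch_buffer = prefetch_buffer[:need_size], prefetch_buffer[need_size:]
            return out
        # Refill: take at least a full chunk so small allocations stay cheap.
        alloc_size = max(need_size - len(prefetch_buffer), prefetch_chunk_size)
        new_index = torch.nonzero(mem_state).squeeze(1)[:alloc_size].to(torch.int32)
        mem_state[new_index.long()] = False
        buf = torch.cat([prefetch_buffer, new_index])
        out, prefetch_buffer = buf[:need_size], buf[need_size:]
        return out

    print(alloc(2).tolist())  # [0, 1], leaving [2, 3] banked
    print(alloc(3).tolist())  # [2, 3, 4], partly served from the bank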
@@ -98,6 +117,8 @@ class TokenToKVPool:
         self.can_use_mem_size += len(free_index)
 
     def clear(self):
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+
         self.mem_state.fill_(True)
         self.can_use_mem_size = self.size