sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
- sglang/srt/managers/controller/infer_batch.py +47 -49
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +35 -23
- sglang/srt/managers/controller/tp_worker.py +127 -138
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +19 -6
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +113 -84
- sglang/srt/server_args.py +23 -15
- sglang/srt/utils.py +16 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -61,7 +61,7 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.
+        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
         self.hf_config = get_config(
@@ -69,7 +69,10 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-
+        if server_args.context_length is not None:
+            self.context_len = server_args.context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
             self.processor = get_processor(
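The hunk above lets an explicit server argument override the context length derived from the Hugging Face config. A minimal standalone sketch of that fallback (illustrative only, not sglang code; the helper name and sample values are made up):

from types import SimpleNamespace

# Illustrative helper mirroring the override added above: an explicit
# context_length from the server arguments wins over the model-config value.
def resolve_context_len(server_args, config_context_len: int) -> int:
    if server_args.context_length is not None:
        return server_args.context_length
    return config_context_len

print(resolve_context_len(SimpleNamespace(context_length=None), 4096))   # 4096 (from config)
print(resolve_context_len(SimpleNamespace(context_length=8192), 4096))   # 8192 (explicit override)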
@@ -119,125 +122,150 @@ class TokenizerManager:

         obj.post_init()
         is_single = obj.is_single
-        if is_single:
-            rid = obj.rid
-
-            if obj.input_ids is None:
-                input_ids = self.tokenizer.encode(obj.text)
-            else:
-                input_ids = obj.input_ids

-
-
-
-
-
+        if is_single:
+            async for response in self._handle_single_request(obj, request):
+                yield response
+        else:
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

-
-
-            sampling_params.normalize(self.tokenizer)
-            sampling_params.verify()
+            async for response in self._handle_batch_request(obj, request):
+                yield response

-
-
-
-
-
-            pixel_values, image_hash, image_size = await self.get_pixel_values(
-                obj.image_data
-            )
+    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
+        if is_prefill:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]
+        else:
+            rid = obj.rid if index is None else obj.rid[index]
+            input_text = obj.text if index is None else obj.text[index]
+            input_ids = (
+                self.tokenizer.encode(input_text)
+                if obj.input_ids is None
+                else obj.input_ids
             )
-
+            if index is not None and obj.input_ids:
+                input_ids = obj.input_ids[index]

-
-
-
+            self._validate_input_length(input_ids)
+            sampling_params = self._get_sampling_params(
+                obj.sampling_params if index is None else obj.sampling_params[index]
+            )
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data if index is None else obj.image_data[index]
+            )
+            return_logprob = (
+                obj.return_logprob if index is None else obj.return_logprob[index]
+            )
+            logprob_start_len = (
+                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+            )
+            top_logprobs_num = (
+                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+            )

-
-
-
-
-
-
-
+        tokenized_obj = TokenizedGenerateReqInput(
+            rid,
+            input_text,
+            input_ids,
+            pixel_values,
+            image_hash,
+            image_size,
+            sampling_params,
+            return_logprob,
+            logprob_start_len,
+            top_logprobs_num,
+            obj.stream,
+        )
+        self.send_to_router.send_pyobj(tokenized_obj)
+
+        event = asyncio.Event()
+        state = ReqState([], False, event)
+        self.rid_to_state[rid] = state
+        if is_prefill:
+            await self._wait_for_prefill_response(event, state, obj, request, rid)
+            yield input_ids
+        else:
+            async for response in self._wait_for_response(
+                event, state, obj, rid, request
+            ):
+                yield response
+
+    async def _handle_batch_request(self, obj, request):
+        batch_size = obj.batch_size
+        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+
+        if parallel_sample_num != 1:
+            ## send prefill requests
+            parallel_sample_num += 1
+            input_id_result = [] if obj.input_ids is None else None
+            for i in range(batch_size):
+                async for input_id in self._handle_single_request(
+                    obj, request, index=i, is_prefill=True
+                ):
+                    if input_id_result is not None:
+                        input_id_result.append(input_id)
+                    pass
+            if len(input_id_result) > 1 and input_id_result is not None:
+                obj.input_ids = input_id_result
+            elif input_id_result is not None:
+                obj.input_ids = input_id_result[0]
+        # First send out all requests
+        for i in range(batch_size):
+            for j in range(parallel_sample_num):
+                if j == 0 and parallel_sample_num != 1:
                     continue
-
-
-
-
-
-
+                index = i * parallel_sample_num + j
+                if parallel_sample_num != 1:
+                    # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
+                    index += batch_size - 1 - i
+                rid = obj.rid[index]
+                if parallel_sample_num == 1:
+                    ## select operation
+                    if obj.input_ids is None:
+                        input_text = obj.text[i]
+                        input_ids = self.tokenizer.encode(obj.text[i])
+                    else:
+                        input_text = None
+                        input_ids = obj.input_ids[i]
+                else:
+                    if batch_size == 1:
+                        input_text = obj.text
+                        input_ids = obj.input_ids
+                    else:
+                        input_text = obj.text[i]
+                        input_ids = obj.input_ids[i]
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
                 )

-                if self.server_args.log_requests and state.finished:
-                    logger.info(f"in={obj.text}, out={out}")
-
-                state.out_list = []
-                if state.finished:
-                    del self.rid_to_state[rid]
-
-                    yield out
-
-                    break
-
-                event.clear()
-
-                yield out
-        else:
-            if obj.stream:
-                raise ValueError("Do not support stream for batch mode.")
-
-            if obj.input_ids is None:
-                bs = len(obj.text)
-            else:
-                bs = len(obj.input_ids)
-
-            for i in range(bs):
-                rid = obj.rid[i]
-
-                if obj.input_ids is None:
-                    input_text = obj.text[i]
-                    input_ids = self.tokenizer.encode(obj.text[i])
-                else:
-                    input_text = None
-                    input_ids = obj.input_ids[i]
-
-                sampling_params = SamplingParams(**obj.sampling_params[i])
-                if sampling_params.max_new_tokens != 0:
-                    sampling_params.normalize(self.tokenizer)
-                    sampling_params.verify()
-                if obj.image_data[i] is None:
-                    pixel_values, image_hash, image_size = None, None, None
-                else:
-                    pixel_values, image_hash, image_size = await self.get_pixel_values(
-                        obj.image_data[i]
-                    )
                 tokenized_obj = TokenizedGenerateReqInput(
-                    rid
-                    input_text
-                    input_ids
-                    pixel_values
-                    image_hash
-                    image_size
-                    sampling_params
-
-
-
-
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)

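The `_handle_batch_request` loop above folds the per-prompt prefill requests into the request-id indexing. A tiny standalone sketch of that arithmetic (illustrative only, not sglang code), assuming a batch of 2 prompts with n=2 parallel samples, so parallel_sample_num is bumped to 3:

# Standalone illustration of the index arithmetic in _handle_batch_request above.
# Assumed setup: batch_size=2 prompts, sampling n=2, hence parallel_sample_num=3
# after the prefill bump; slot j=0 of each prompt was consumed by its prefill request.
batch_size = 2
parallel_sample_num = 3

for i in range(batch_size):
    for j in range(parallel_sample_num):
        if j == 0 and parallel_sample_num != 1:
            continue
        index = i * parallel_sample_num + j
        if parallel_sample_num != 1:
            # identical to j + i * (parallel_sample_num - 1) + batch_size - 1
            index += batch_size - 1 - i
        print(f"prompt {i}, sample {j} -> rid index {index}")
# prints rid indices 2, 3, 4, 5; indices 0 and 1 were used by the prefill requests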
@@ -245,9 +273,16 @@ class TokenizerManager:
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

-
-
-
+        # Then wait for all responses
+        output_list = []
+        for i in range(batch_size):
+            for j in range(parallel_sample_num):
+                if j == 0 and parallel_sample_num != 1:
+                    continue
+                index = i * parallel_sample_num + j
+                if parallel_sample_num != 1:
+                    index += batch_size - 1 - i
+                rid = obj.rid[index]
                 state = self.rid_to_state[rid]

                 while True:
@@ -260,19 +295,86 @@ class TokenizerManager:
                            self.abort_request(rid)
                            raise ValueError(f"Abort request {rid}")
                        continue
-
                    output_list.append(
                        self.convert_logprob_style(
                            state.out_list[-1],
-                            obj.return_logprob[
-                            obj.top_logprobs_num[
+                            obj.return_logprob[index],
+                            obj.top_logprobs_num[index],
                            obj.return_text_in_logprobs,
                        )
                    )
                    assert state.finished
                    del self.rid_to_state[rid]

-
+        yield output_list
+
+    def _validate_input_length(self, input_ids):
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+        sampling_params = SamplingParams(**sampling_params_data)
+        if max_new_tokens is not None:
+            sampling_params.max_new_tokens = max_new_tokens
+        if sampling_params.max_new_tokens != 0:
+            sampling_params.normalize(self.tokenizer)
+            sampling_params.verify()
+        return sampling_params
+
+    async def _get_pixel_values(self, image_data):
+        if isinstance(image_data, list) and len(image_data) > 0:
+            return await self.get_pixel_values(image_data[0])
+        elif isinstance(image_data, str):
+            return await self.get_pixel_values(image_data)
+        else:
+            return None, None, None
+
+    async def _wait_for_response(self, event, state, obj, rid, request):
+        while True:
+            try:
+                await asyncio.wait_for(event.wait(), timeout=4)
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(rid)
+                    raise ValueError(f"Abort request {rid}")
+                continue
+
+            out = self.convert_logprob_style(
+                state.out_list[-1],
+                obj.return_logprob,
+                obj.top_logprobs_num,
+                obj.return_text_in_logprobs,
+            )
+
+            if self.server_args.log_requests and state.finished:
+                logger.info(f"in={obj.text}, out={out}")
+
+            state.out_list = []
+            if state.finished:
+                del self.rid_to_state[rid]
+                yield out
+                break
+
+            event.clear()
+            yield out
+
+    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+        while True:
+            try:
+                await asyncio.wait_for(state.event.wait(), timeout=4)
+                break
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    for rid in obj.rid:
+                        self.abort_request(rid)
+                    raise ValueError(f"Abort request {rid}")
+                continue
+
+        assert state.finished
+        del self.rid_to_state[rid]

     def flush_cache(self):
         req = FlushCacheReq()
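The `_wait_for_response` helper added above waits on an asyncio.Event in 4-second slices so that every timeout becomes an opportunity to notice a disconnected client and abort. A self-contained sketch of the same pattern (illustrative only, not sglang code; names and timings are made up):

import asyncio

# Illustration of the wait-with-disconnect-check pattern: block on the event with a
# timeout, and on each timeout poll the disconnect check before waiting again.
async def wait_with_disconnect_check(event, is_disconnected, timeout=4.0):
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
            return True          # producer signalled a result
        except asyncio.TimeoutError:
            if await is_disconnected():
                return False     # caller would abort the request here

async def never_disconnected():
    return False

async def main():
    event = asyncio.Event()
    asyncio.get_running_loop().call_later(0.12, event.set)  # simulate a late result
    done = await wait_with_disconnect_check(event, never_disconnected, timeout=0.05)
    print("finished" if done else "aborted")

asyncio.run(main())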
@@ -335,15 +437,16 @@ class TokenizerManager:
            )

            if top_logprobs_num > 0:
-                ret["meta_info"][
-
-
-
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"],
+                        return_text_in_logprobs,
+                    )
                )
-                ret["meta_info"][
-
-
-
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                )
        return ret

sglang/srt/memory_pool.py
CHANGED
@@ -21,7 +21,9 @@ class ReqToTokenPool:
        if need_size > self.can_use_mem_size:
            return None

-        select_index =
+        select_index = (
+            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
+        )
        self.mem_state[select_index] = False
        self.can_use_mem_size -= need_size

@@ -42,7 +44,14 @@ class ReqToTokenPool:
 class TokenToKVPool:
    """A memory pool that maps a token to its kv cache locations"""

-    def __init__(
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
        self.size = size

        # We also add one slot. This slot is used for writing dummy output from padded tokens.
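The constructor above now spells out its typed parameters. A hypothetical construction call matching that signature (the model dimensions are illustrative placeholders, and the pool allocates GPU buffers, so this needs a CUDA-capable sglang install):

import torch
from sglang.srt.memory_pool import TokenToKVPool

# Hypothetical instantiation matching the typed signature above; the sizes are
# roughly Llama-2-7B-shaped placeholders, not values taken from this diff.
kv_pool = TokenToKVPool(
    size=4096,           # number of token slots
    dtype=torch.float16,
    head_num=32,
    head_dim=128,
    layer_num=32,
)
print(kv_pool.available_size())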
@@ -61,16 +70,16 @@ class TokenToKVPool:
        self.can_use_mem_size = self.size
        self.clear()

-    def get_key_buffer(self, layer_id):
+    def get_key_buffer(self, layer_id: int):
        return self.kv_data[layer_id][:, 0]

-    def get_value_buffer(self, layer_id):
+    def get_value_buffer(self, layer_id: int):
        return self.kv_data[layer_id][:, 1]

    def available_size(self):
        return self.can_use_mem_size + len(self.prefetch_buffer)

-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
        buffer_len = len(self.prefetch_buffer)
        if need_size <= buffer_len:
            select_index = self.prefetch_buffer[:need_size]
@@ -79,7 +88,9 @@ class TokenToKVPool:

        addition_size = need_size - buffer_len
        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index =
+        select_index = (
+            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
+        )

        if select_index.shape[0] < addition_size:
            return None
@@ -98,6 +109,8 @@ class TokenToKVPool:
        self.can_use_mem_size += len(free_index)

    def clear(self):
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+
        self.mem_state.fill_(True)
        self.can_use_mem_size = self.size

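`clear()` now resets an empty prefetch_buffer, which `alloc()` refills in chunks of at least prefetch_chunk_size and `available_size()` counts as free capacity. A self-contained, CPU-only illustration of that chunked-prefetch idea (not sglang code; TinyPool and its sizes are made up):

import torch

class TinyPool:
    def __init__(self, size: int, prefetch_chunk_size: int = 4):
        self.mem_state = torch.ones(size, dtype=torch.bool)          # True = slot free
        self.prefetch_buffer = torch.empty(0, dtype=torch.int32)     # leftover free slots
        self.prefetch_chunk_size = prefetch_chunk_size

    def alloc(self, need_size: int):
        # Serve from the prefetch buffer when it already holds enough free slots.
        if need_size <= len(self.prefetch_buffer):
            out = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return out
        # Otherwise grab at least a full chunk of free slots in one scan.
        addition_size = need_size - len(self.prefetch_buffer)
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
        if select_index.shape[0] < addition_size:
            return None
        self.mem_state[select_index] = False
        out = torch.cat([self.prefetch_buffer, select_index[:addition_size]])
        self.prefetch_buffer = select_index[addition_size:]
        return out

pool = TinyPool(16)
print(pool.alloc(3))  # scans a chunk of 4 free slots, returns 3, keeps 1 prefetched
print(pool.alloc(1))  # served directly from the prefetch buffer, no scan needed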