sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=port_args.
+        nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


+@torch.inference_mode()
 def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
@@ -231,26 +232,28 @@ def extend(reqs, model_runner):
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


+@torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits


-@torch.inference_mode()
 def correctness_test(
     server_args,
     port_args,
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
@@ -278,8 +281,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):
@@ -287,9 +291,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


-
+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -312,10 +322,10 @@ def latency_test_run_once(
     tot_latency = 0

     # Prefill
-    torch.cuda.synchronize()
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-    torch.cuda.synchronize()
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
@@ -328,10 +338,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-        torch.cuda.synchronize()
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-        torch.cuda.synchronize()
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency
@@ -387,6 +397,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )
     rank_print("Benchmark ...")

@@ -397,7 +408,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name,
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
sglang/bench_server_latency.py
CHANGED
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 Usage:

 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(


 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-    proc, base_url = launch_server_process(server_args)
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-    run_one_case(
-        base_url,
-        batch_size=16,
-        input_len=1024,
-        output_len=16,
-        run_name="",
-        result_filename="",
-    )
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )

     # benchmark
     try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 bench_args.result_filename,
             )
     finally:
-        kill_child_process(proc.pid)
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")
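
With the new --base-url and --skip-warmup flags the script can attach to a server that is already running instead of spawning its own. The sketch below isolates that launch-or-attach decision; resolve_server and launch_fn are illustrative names, while the real script does this inline in run_benchmark using launch_server_process and kill_child_process.

from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass
class BenchArgs:
    base_url: str = ""         # when set, reuse an already-running server
    skip_warmup: bool = False  # when set, skip the fixed warmup case


def resolve_server(
    bench_args: BenchArgs, launch_fn: Callable[[], Tuple[object, str]]
) -> Tuple[Optional[object], str]:
    # Reuse the external server if --base-url was given; otherwise launch one
    # and keep its process handle so the caller can kill it in a finally block.
    if bench_args.base_url:
        return None, bench_args.base_url
    return launch_fn()

Because the handle is None when attaching to an external server, the cleanup step only kills the process it launched itself, which is what lets the benchmark leave an externally managed server running.
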
sglang/bench_serving.py
CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -264,7 +343,9 @@ def get_tokenizer(


 ASYNC_REQUEST_FUNCS = {
-    "sglang":
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
     return filtered_dataset


@@ -587,6 +670,8 @@ async def benchmark(
     else:
         print("Initial test run completed. Starting main benchmark run...")

+    time.sleep(1.5)
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()
@@ -782,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )
-    if args.backend
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
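
The new sglang / sglang-native backend streams from the server's native /generate endpoint using the payload shape shown above (text, sampling_params, stream). Below is a standalone sketch of the same request, assuming a server at http://localhost:30000 and using requests instead of the script's aiohttp session.

import json

import requests


def generate_stream(prompt: str, base_url: str = "http://localhost:30000", max_new_tokens: int = 32):
    # Same payload shape as async_request_sglang_generate: greedy sampling,
    # a fixed token budget, and server-side streaming enabled.
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": max_new_tokens,
            "ignore_eos": True,
        },
        "stream": True,
    }
    with requests.post(f"{base_url}/generate", json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue
            chunk = line.removeprefix("data: ")
            if chunk == "[DONE]":
                break
            # Each chunk carries the full text generated so far.
            yield json.loads(chunk)["text"]


# Example: print the partial text as it streams in.
# for text in generate_stream("The capital of France is"):
#     print(text)
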
sglang/global_config.py
CHANGED
@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
sglang/srt/conversation.py
CHANGED
@@ -70,6 +70,9 @@ class Conversation:
     sep2: str = None
     # Stop criteria (the default one is EOS token)
     stop_str: Union[str, List[str]] = None
+    # The string that represents an image token in the prompt
+    image_token: str = "<image>"
+
     image_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None

@@ -334,6 +337,7 @@ class Conversation:
             sep=self.sep,
             sep2=self.sep2,
             stop_str=self.stop_str,
+            image_token=self.image_token,
         )

     def dict(self):
@@ -381,6 +385,7 @@ def generate_chat_conv(
         stop_str=conv.stop_str,
         image_data=[],
         modalities=[],
+        image_token=conv.image_token,
     )

     if isinstance(request.messages, str):
@@ -412,9 +417,13 @@
                     num_image_url += 1
                     conv.modalities.append(content.modalities)
             if num_image_url > 1:
-                image_token =
+                image_token = conv.image_token
             else:
-                image_token =
+                image_token = (
+                    conv.image_token + "\n"
+                    if conv.name != "qwen2-vl"
+                    else conv.image_token
+                )
             for content in message.content:
                 if content.type == "text":
                     if num_image_url > 16:
sglang/srt/layers/attention/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+import torch
 from torch import nn

 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()