sglang-0.1.14-py3-none-any.whl → sglang-0.1.21-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/cuda_graph_runner.py
@@ -0,0 +1,196 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer,
+            "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs]
+                if output.next_token_logprobs is not None
+                else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+                if output.decode_top_logprobs is not None
+                else None,
+            )
+        return output
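To make the padding scheme in replay() concrete: graphs are captured only for the sizes in batch_size_list, an incoming batch is padded up to the nearest captured size with bisect, and the output is sliced back down to the real batch size. The following self-contained sketch (not part of the diff) illustrates that lookup-and-slice step; captured_sizes and fake_forward are hypothetical stand-ins for the captured graphs.

import bisect

# Hypothetical captured sizes, mirroring batch_size_list in CudaGraphRunner.capture().
captured_sizes = [1, 2, 4, 8, 16, 32]

def pick_padded_size(raw_bs):
    # Same lookup as replay(): the smallest captured size >= raw_bs.
    index = bisect.bisect_left(captured_sizes, raw_bs)
    return captured_sizes[index]

def fake_forward(padded_bs):
    # Stand-in for a captured graph replay that always produces padded_bs rows.
    return list(range(padded_bs))

raw_bs = 5
padded_bs = pick_padded_size(raw_bs)        # -> 8
output = fake_forward(padded_bs)[:raw_bs]   # slice back to the real batch size
print(padded_bs, output)                    # 8 [0, 1, 2, 3, 4]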
sglang/srt/managers/controller/dp_worker.py
@@ -0,0 +1,113 @@
+"""A data parallel worker thread."""
+
+import asyncio
+import logging
+import queue
+import threading
+from typing import Callable, List
+
+import uvloop
+import zmq
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.managers.io_struct import BatchTokenIDOut
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+CHECKING_INTERVAL = 5
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+class DataParallelWorkerThread(threading.Thread):
+    def __init__(
+        self,
+        worker_id: int,
+        request_queue: queue.Queue,
+        detokenizer_port: int,
+        step_func: Callable,
+    ):
+        super(DataParallelWorkerThread, self).__init__()
+        self.worker_id = worker_id
+        self.request_queue = request_queue
+        self.liveness = True
+        self.request_dependency_delay = global_config.request_dependency_delay
+
+        context = zmq.asyncio.Context()
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
+
+        self.step = step_func
+
+    async def loop_for_forward(self):
+        while self.liveness:
+            requests = []
+            while not self.request_queue.empty():
+                requests.append(self.request_queue.get())
+
+            out_pyobjs: List[BatchTokenIDOut] = []
+            try:
+                out_pyobjs = await self.step(requests)
+            except Exception:
+                for r in requests:
+                    self.request_queue.put(r)
+                logger.error(
+                    f"Worker thread {self.worker_id}: "
+                    f"failed to get back from Model Server\n"
+                    f"{get_exception_traceback()}"
+                )
+                self.liveness = False
+                # Crash the whole server when there are any errors.
+                # TODO(lianmin): make this an option.
+                kill_parent_process()
+                return
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+            # async sleep for receiving the subsequent request and avoiding cache miss
+            if len(out_pyobjs) != 0:
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
+                if has_finished:
+                    await asyncio.sleep(self.request_dependency_delay)
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def monitoring(self):
+        while True:
+            await asyncio.sleep(CHECKING_INTERVAL)
+            # can plug in monitoring logic here
+
+    def run(self):
+        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.create_task(self.monitoring())
+        loop.run_until_complete(self.loop_for_forward())
+
+
+def start_data_parallel_worker(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    model_overide_args,
+    gpu_ids: List[int],
+    worker_id: int,
+):
+    model_tp_client = ModelTpClient(
+        gpu_ids,
+        server_args,
+        port_args.model_port_args[worker_id],
+        model_overide_args,
+    )
+    worker_thread = DataParallelWorkerThread(
+        worker_id=worker_id,
+        request_queue=queue.Queue(),
+        detokenizer_port=port_args.detokenizer_port,
+        step_func=model_tp_client.step,
+    )
+    worker_thread.start()
+    return worker_thread
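The core pattern above is a thread that owns a private asyncio event loop and drains a thread-safe queue.Queue, so the controller can hand requests to each data-parallel worker from another thread without sharing an event loop. Below is a minimal, self-contained sketch of that pattern (not part of the diff); WorkerThread, the printed "step", and the 0.01 s sleep are illustrative stand-ins for the real ModelTpClient.step and the global_config delays.

import asyncio
import queue
import threading
import time

class WorkerThread(threading.Thread):
    """Sketch of the thread-plus-private-event-loop pattern used by DataParallelWorkerThread."""

    def __init__(self, request_queue: queue.Queue):
        super().__init__(daemon=True)
        self.request_queue = request_queue
        self.liveness = True

    async def loop_for_forward(self):
        while self.liveness:
            # Drain everything the controller enqueued since the last step.
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())
            if requests:
                print(f"step over {len(requests)} request(s): {requests}")
            await asyncio.sleep(0.01)  # stand-in for wait_for_new_request_delay

    def run(self):
        # Each worker owns its own event loop, as in DataParallelWorkerThread.run().
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.loop_for_forward())

# Usage: the controller thread just puts work on the queue.
q = queue.Queue()
worker = WorkerThread(q)
worker.start()
for i in range(3):
    q.put(f"req-{i}")
time.sleep(0.1)
worker.liveness = False
worker.join()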