sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention_backend.py
@@ -1,482 +0,0 @@
- from __future__ import annotations
-
- """
- Support different attention backends.
- Now there are two backends: FlashInfer and Triton.
- FlashInfer is faster and Triton is easier to customize.
- Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
- """
-
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING
-
- import torch
- import torch.nn as nn
-
- from sglang.global_config import global_config
- from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
- from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
- from sglang.srt.utils import is_hip
-
- if TYPE_CHECKING:
-     from sglang.srt.model_executor.model_runner import ModelRunner
-
- # ROCm: flashinfer available later
- if not is_hip():
-     from flashinfer import (
-         BatchDecodeWithPagedKVCacheWrapper,
-         BatchPrefillWithPagedKVCacheWrapper,
-         BatchPrefillWithRaggedKVCacheWrapper,
-     )
-     from flashinfer.cascade import merge_state
-     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-
- class AttentionBackend(ABC):
-     """The base class of attention backends"""
-
-     @abstractmethod
-     def init_forward_metadata(
-         self, batch: ScheduleBatch, input_metadata: InputMetadata
-     ):
-         """Init the metadata for a forward pass."""
-         raise NotImplementedError()
-
-     def init_cuda_graph_state(self, max_bs: int):
-         """Init the global shared states for cuda graph."""
-         raise NotImplementedError()
-
-     def init_forward_metadata_capture_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         """Init the metadata for a forward pass for capturing a cuda graph."""
-         raise NotImplementedError()
-
-     def init_forward_metadata_replay_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         """Init the metadata for a forward pass for replying a cuda graph."""
-         raise NotImplementedError()
-
-     def get_cuda_graph_seq_len_fill_value(self):
-         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
-         raise NotImplementedError()
-
-     def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         """Run forward on an attention layer."""
-         if input_metadata.forward_mode.is_decode():
-             return self.forward_decode(q, k, v, layer, input_metadata)
-         else:
-             return self.forward_extend(q, k, v, layer, input_metadata)
-
-     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         """Run a forward for decode."""
-         raise NotImplementedError()
-
-     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         """Run a forward for extend."""
-         raise NotImplementedError()
-
-
- class FlashInferAttnBackend(AttentionBackend):
-     """Flashinfer attention kernels."""
-
-     def __init__(self, model_runner: ModelRunner):
-         super().__init__()
-         self.model_runner = model_runner
-
-         local_num_qo_heads = (
-             model_runner.model_config.num_attention_heads // model_runner.tp_size
-         )
-         local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
-             model_runner.tp_size
-         )
-         if (
-             not _grouped_size_compiled_for_decode_kernels(
-                 local_num_qo_heads, local_num_kv_heads
-             )
-             or local_num_qo_heads // local_num_kv_heads > 4
-         ):
-             self.decode_use_tensor_cores = True
-         else:
-             self.decode_use_tensor_cores = False
-
-         self.workspace_buffer = torch.empty(
-             global_config.flashinfer_workspace_size,
-             dtype=torch.uint8,
-             device="cuda",
-         )
-
-         if model_runner.sliding_window_size is None:
-             self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                 self.workspace_buffer, "NHD"
-             )
-             self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                 self.workspace_buffer, "NHD"
-             )
-             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 self.workspace_buffer,
-                 "NHD",
-                 use_tensor_cores=self.decode_use_tensor_cores,
-             )
-         else:
-             # Two wrappers: one for sliding window attention and one for full attention.
-             # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-             self.prefill_wrapper_ragged = None
-             self.prefill_wrapper_paged = []
-             self.decode_wrapper = []
-             for _ in range(2):
-                 self.prefill_wrapper_paged.append(
-                     BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                 )
-                 self.decode_wrapper.append(
-                     BatchDecodeWithPagedKVCacheWrapper(
-                         self.workspace_buffer,
-                         "NHD",
-                         use_tensor_cores=self.decode_use_tensor_cores,
-                     )
-                 )
-
-         self.forward_metadata = None
-         self.cuda_graph_metadata = {}
-
-     def init_forward_metadata(
-         self, batch: ScheduleBatch, input_metadata: InputMetadata
-     ):
-         if input_metadata.forward_mode.is_decode():
-             prefix_lens = None
-             use_ragged = False
-             total_num_tokens = None
-         else:
-             prefix_lens = input_metadata.extend_prefix_lens
-
-             # Some heuristics to check whether to use ragged forward
-             use_ragged = False
-             if (
-                 torch.sum(input_metadata.seq_lens).item() >= 4096
-                 and self.model_runner.sliding_window_size is None
-             ):
-                 use_ragged = True
-
-             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-
-         update_flashinfer_indices(
-             input_metadata.forward_mode,
-             self.model_runner,
-             input_metadata.req_pool_indices,
-             input_metadata.seq_lens,
-             prefix_lens,
-             use_ragged=use_ragged,
-         )
-
-         self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
-
-     def init_cuda_graph_state(self, max_bs: int):
-         self.cuda_graph_kv_indptr = torch.zeros(
-             (max_bs + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.cuda_graph_kv_indices = torch.zeros(
-             (max_bs * self.model_runner.model_config.context_len,),
-             dtype=torch.int32,
-             device="cuda",
-         )
-         self.cuda_graph_kv_last_page_len = torch.ones(
-             (max_bs,), dtype=torch.int32, device="cuda"
-         )
-
-         if self.model_runner.sliding_window_size is not None:
-             self.cuda_graph_kv_indptr = [
-                 self.cuda_graph_kv_indptr,
-                 self.cuda_graph_kv_indptr.clone(),
-             ]
-             self.cuda_graph_kv_indices = [
-                 self.cuda_graph_kv_indices,
-                 self.cuda_graph_kv_indices.clone(),
-             ]
-
-     def init_forward_metadata_capture_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         if self.model_runner.sliding_window_size is None:
-             decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 self.workspace_buffer,
-                 "NHD",
-                 use_cuda_graph=True,
-                 use_tensor_cores=self.decode_use_tensor_cores,
-                 paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
-                 paged_kv_indices_buffer=self.cuda_graph_kv_indices,
-                 paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
-             )
-         else:
-             decode_wrapper = []
-             for i in range(2):
-                 decode_wrapper.append(
-                     BatchDecodeWithPagedKVCacheWrapper(
-                         self.workspace_buffer,
-                         "NHD",
-                         use_cuda_graph=True,
-                         use_tensor_cores=self.decode_use_tensor_cores,
-                         paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
-                         paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                         paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
-                             :bs
-                         ],
-                     )
-                 )
-
-         update_flashinfer_indices(
-             ForwardMode.DECODE,
-             self.model_runner,
-             req_pool_indices,
-             seq_lens,
-             None,
-             decode_wrapper,
-         )
-
-         self.cuda_graph_metadata[bs] = decode_wrapper
-
-         self.forward_metadata = (False, None, decode_wrapper)
-
-     def init_forward_metadata_replay_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         update_flashinfer_indices(
-             ForwardMode.DECODE,
-             self.model_runner,
-             req_pool_indices[:bs],
-             seq_lens[:bs],
-             None,
-             self.cuda_graph_metadata[bs],
-         )
-
-     def get_cuda_graph_seq_len_fill_value(self):
-         return 0
-
-     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         if not isinstance(self.prefill_wrapper_paged, list):
-             prefill_wrapper_paged = self.prefill_wrapper_paged
-         else:
-             if layer.sliding_window_size != -1:
-                 prefill_wrapper_paged = self.prefill_wrapper_paged[0]
-             else:
-                 prefill_wrapper_paged = self.prefill_wrapper_paged[1]
-
-         use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
-
-         if not use_ragged:
-             if k is not None:
-                 assert v is not None
-                 input_metadata.token_to_kv_pool.set_kv_buffer(
-                     layer.layer_id, input_metadata.out_cache_loc, k, v
-                 )
-             o = prefill_wrapper_paged.forward(
-                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                 input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                 causal=True,
-                 sm_scale=layer.scaling,
-                 window_left=layer.sliding_window_size,
-                 logits_soft_cap=layer.logit_cap,
-             )
-         else:
-             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                 k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
-                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
-                 causal=True,
-                 sm_scale=layer.scaling,
-                 logits_soft_cap=layer.logit_cap,
-             )
-
-             if input_metadata.extend_no_prefix:
-                 o = o1
-             else:
-                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                     input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                     causal=False,
-                     sm_scale=layer.scaling,
-                     logits_soft_cap=layer.logit_cap,
-                 )
-
-                 o, _ = merge_state(o1, s1, o2, s2)
-
-             input_metadata.token_to_kv_pool.set_kv_buffer(
-                 layer.layer_id, input_metadata.out_cache_loc, k, v
-             )
-
-         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
-
-     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
-
-         if isinstance(decode_wrapper, list):
-             if layer.sliding_window_size != -1:
-                 decode_wrapper = decode_wrapper[0]
-             else:
-                 decode_wrapper = decode_wrapper[1]
-
-         if k is not None:
-             assert v is not None
-             input_metadata.token_to_kv_pool.set_kv_buffer(
-                 layer.layer_id, input_metadata.out_cache_loc, k, v
-             )
-
-         o = decode_wrapper.forward(
-             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-             input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-             sm_scale=layer.scaling,
-             logits_soft_cap=layer.logit_cap,
-         )
-
-         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
-
-
- class TritonAttnBackend(AttentionBackend):
-     def __init__(self, model_runner: ModelRunner):
-         # Lazy import to avoid the initialization of cuda context
-         from sglang.srt.layers.triton_attention.decode_attention import (
-             decode_attention_fwd,
-         )
-         from sglang.srt.layers.triton_attention.extend_attention import (
-             extend_attention_fwd,
-         )
-
-         super().__init__()
-
-         self.decode_attention_fwd = decode_attention_fwd
-         self.extend_attention_fwd = extend_attention_fwd
-         self.num_head = (
-             model_runner.model_config.num_attention_heads // model_runner.tp_size
-         )
-
-         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-             self.reduce_dtype = torch.float32
-         else:
-             self.reduce_dtype = torch.float16
-
-         self.forward_metadata = None
-
-         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
-     def init_forward_metadata(
-         self, batch: ScheduleBatch, input_metadata: InputMetadata
-     ):
-         """Init auxiliary variables for triton attention backend."""
-
-         if input_metadata.forward_mode.is_decode():
-             start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
-             start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
-
-             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-             attn_logits = torch.empty(
-                 (self.num_head, total_num_tokens),
-                 dtype=self.reduce_dtype,
-                 device="cuda",
-             )
-
-             max_seq_len = torch.max(input_metadata.seq_lens).item()
-             max_extend_len = None
-         else:
-             start_loc = attn_logits = max_seq_len = None
-             prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
-
-         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
-
-     def init_cuda_graph_state(self, max_bs: int):
-         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-         self.cuda_graph_start_loc = torch.zeros(
-             (max_bs,), dtype=torch.int32, device="cuda"
-         )
-         self.cuda_graph_attn_logits = torch.empty(
-             (
-                 self.num_head,
-                 self.cuda_graph_max_total_num_tokens,
-             ),
-             dtype=self.reduce_dtype,
-             device="cuda",
-         )
-
-     def init_forward_metadata_capture_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         self.forward_metadata = (
-             self.cuda_graph_start_loc,
-             self.cuda_graph_attn_logits,
-             self.cuda_graph_max_seq_len,
-             None,
-         )
-
-     def init_forward_metadata_replay_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         self.cuda_graph_start_loc.zero_()
-         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-     def get_cuda_graph_seq_len_fill_value(self):
-         return 1
-
-     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         # TODO: reuse the buffer across layers
-         if layer.qk_head_dim != layer.v_head_dim:
-             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-         else:
-             o = torch.empty_like(q)
-
-         input_metadata.token_to_kv_pool.set_kv_buffer(
-             layer.layer_id, input_metadata.out_cache_loc, k, v
-         )
-
-         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
-         self.extend_attention_fwd(
-             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-             k.contiguous(),
-             v.contiguous(),
-             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-             input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-             input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-             input_metadata.req_to_token_pool.req_to_token,
-             input_metadata.req_pool_indices,
-             input_metadata.seq_lens,
-             input_metadata.extend_seq_lens,
-             input_metadata.extend_start_loc,
-             max_extend_len,
-             layer.scaling,
-             layer.logit_cap,
-         )
-         return o
-
-     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         # During torch.compile, there is a bug in rotary_emb that causes the
-         # output value to have a 3D tensor shape. This reshapes the output correctly.
-         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
-
-         # TODO: reuse the buffer across layers
-         if layer.qk_head_dim != layer.v_head_dim:
-             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-         else:
-             o = torch.empty_like(q)
-
-         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
-
-         input_metadata.token_to_kv_pool.set_kv_buffer(
-             layer.layer_id, input_metadata.out_cache_loc, k, v
-         )
-
-         self.decode_attention_fwd(
-             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-             input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-             input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-             input_metadata.req_to_token_pool.req_to_token,
-             input_metadata.req_pool_indices,
-             start_loc,
-             input_metadata.seq_lens,
-             attn_logits,
-             max_seq_len,
-             layer.scaling,
-             layer.logit_cap,
-         )
-         return o
sglang/srt/managers/controller_multi.py
@@ -1,207 +0,0 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
- """
- A controller that manages multiple data parallel workers.
- Each data parallel worker can manage multiple tensor parallel workers.
- """
-
- import dataclasses
- import logging
- import multiprocessing
- from enum import Enum, auto
-
- import numpy as np
- import zmq
-
- from sglang.srt.managers.controller_single import (
-     start_controller_process as start_controller_process_single,
- )
- from sglang.srt.managers.io_struct import (
-     AbortReq,
-     FlushCacheReq,
-     TokenizedGenerateReqInput,
- )
- from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import configure_logger, kill_parent_process
- from sglang.utils import get_exception_traceback
-
- logger = logging.getLogger(__name__)
-
-
- class LoadBalanceMethod(Enum):
-     """Load balance method."""
-
-     ROUND_ROBIN = auto()
-     SHORTEST_QUEUE = auto()
-
-     @classmethod
-     def from_str(cls, method: str):
-         method = method.upper()
-         try:
-             return cls[method]
-         except KeyError as exc:
-             raise ValueError(f"Invalid load balance method: {method}") from exc
-
-
- @dataclasses.dataclass
- class WorkerHandle:
-     """Store the handle of a data parallel worker."""
-
-     proc: multiprocessing.Process
-     queue: multiprocessing.Queue
-
-
- class ControllerMulti:
-     """A controller that manages multiple data parallel workers."""
-
-     def __init__(
-         self,
-         server_args: ServerArgs,
-         port_args: PortArgs,
-     ):
-         # Parse args
-         self.server_args = server_args
-         self.port_args = port_args
-         self.load_balance_method = LoadBalanceMethod.from_str(
-             server_args.load_balance_method
-         )
-
-         # Init communication
-         context = zmq.Context()
-         self.recv_from_tokenizer = context.socket(zmq.PULL)
-         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
-
-         # Dispatch method
-         self.round_robin_counter = 0
-         dispatch_lookup = {
-             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
-             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
-         }
-         self.dispatching = dispatch_lookup[self.load_balance_method]
-
-         # Start data parallel workers
-         self.workers = []
-         for i in range(server_args.dp_size):
-             self.start_dp_worker(i)
-
-     def start_dp_worker(self, dp_worker_id: int):
-         tp_size = self.server_args.tp_size
-
-         pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
-             duplex=False
-         )
-
-         gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
-         queue = multiprocessing.Queue()
-         proc = multiprocessing.Process(
-             target=start_controller_process_single,
-             args=(
-                 self.server_args,
-                 self.port_args,
-                 pipe_controller_writer,
-                 True,
-                 gpu_ids,
-                 dp_worker_id,
-                 queue,
-             ),
-         )
-         proc.start()
-
-         controller_init_state = pipe_controller_reader.recv()
-         if controller_init_state != "init ok":
-             raise RuntimeError(
-                 f"Initialization failed. controller_init_state: {controller_init_state}"
-             )
-         self.workers.append(
-             WorkerHandle(
-                 proc=proc,
-                 queue=queue,
-             )
-         )
-
-     def round_robin_scheduler(self, input_requests):
-         for r in input_requests:
-             self.workers[self.round_robin_counter].queue.put(r)
-             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                 self.workers
-             )
-
-     def shortest_queue_scheduler(self, input_requests):
-         for r in input_requests:
-             queue_sizes = [worker.queue.qsize() for worker in self.workers]
-             wid = np.argmin(queue_sizes)
-             self.workers[wid].queue.put(r)
-
-     def loop_for_forward(self):
-         while True:
-             recv_reqs = self.recv_requests()
-             self.dispatching(recv_reqs)
-
-     def recv_requests(self):
-         recv_reqs = []
-
-         while True:
-             try:
-                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-             except zmq.ZMQError:
-                 break
-
-             if isinstance(recv_req, FlushCacheReq):
-                 # TODO(lsyin): apply more specific flushCacheReq
-                 for worker in self.workers:
-                     worker.queue.put(recv_req)
-             elif isinstance(recv_req, AbortReq):
-                 in_queue = False
-                 for i, req in enumerate(recv_reqs):
-                     if req.rid == recv_req.rid:
-                         recv_reqs[i] = recv_req
-                         in_queue = True
-                         break
-                 if not in_queue:
-                     # Send abort req to all TP groups
-                     for worker in self.workers:
-                         worker.queue.put(recv_req)
-             elif isinstance(recv_req, TokenizedGenerateReqInput):
-                 recv_reqs.append(recv_req)
-             else:
-                 logger.error(f"Invalid object: {recv_req}")
-
-         return recv_reqs
-
-
- def start_controller_process(
-     server_args: ServerArgs,
-     port_args: PortArgs,
-     pipe_writer,
- ):
-     """Start a controller process."""
-
-     configure_logger(server_args)
-
-     try:
-         controller = ControllerMulti(server_args, port_args)
-     except Exception:
-         pipe_writer.send(get_exception_traceback())
-         raise
-
-     pipe_writer.send("init ok")
-
-     try:
-         controller.loop_for_forward()
-     except Exception:
-         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
-     finally:
-         kill_parent_process()