sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention_backend.py (deleted)
@@ -1,474 +0,0 @@
- from __future__ import annotations
-
- """
- Support different attention backends.
- Now there are two backends: FlashInfer and Triton.
- FlashInfer is faster and Triton is easier to customize.
- Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
- """
-
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING
-
- import torch
- import torch.nn as nn
-
- from sglang.global_config import global_config
- from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
- from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
- from sglang.srt.utils import is_hip
-
- if TYPE_CHECKING:
-     from sglang.srt.model_executor.model_runner import ModelRunner
-
- # ROCm: flashinfer available later
- if not is_hip():
-     from flashinfer import (
-         BatchDecodeWithPagedKVCacheWrapper,
-         BatchPrefillWithPagedKVCacheWrapper,
-         BatchPrefillWithRaggedKVCacheWrapper,
-     )
-     from flashinfer.cascade import merge_state
-     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-
- class AttentionBackend(ABC):
-     """The base class of attention backends"""
-
-     @abstractmethod
-     def init_forward_metadata(
-         self, batch: ScheduleBatch, input_metadata: InputMetadata
-     ):
-         """Init the metadata for a forward pass."""
-         raise NotImplementedError()
-
-     def init_cuda_graph_state(self, max_bs: int):
-         """Init the global shared states for cuda graph."""
-         raise NotImplementedError()
-
-     def init_forward_metadata_capture_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         """Init the metadata for a forward pass for capturing a cuda graph."""
-         raise NotImplementedError()
-
-     def init_forward_metadata_replay_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         """Init the metadata for a forward pass for replaying a cuda graph."""
-         raise NotImplementedError()
-
-     def get_cuda_graph_seq_len_fill_value(self):
-         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
-         raise NotImplementedError()
-
-     def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         """Run forward on an attention layer."""
-         if input_metadata.forward_mode.is_decode():
-             return self.forward_decode(q, k, v, layer, input_metadata)
-         else:
-             return self.forward_extend(q, k, v, layer, input_metadata)
-
-     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         """Run a forward for decode."""
-         raise NotImplementedError()
-
-     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         """Run a forward for extend."""
-         raise NotImplementedError()
-
-
- class FlashInferAttnBackend(AttentionBackend):
-     """Flashinfer attention kernels."""
-
-     def __init__(self, model_runner: ModelRunner):
-         super().__init__()
-         self.model_runner = model_runner
-
-         if not _grouped_size_compiled_for_decode_kernels(
-             model_runner.model_config.num_attention_heads // model_runner.tp_size,
-             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-         ):
-             self.decode_use_tensor_cores = True
-         else:
-             self.decode_use_tensor_cores = False
-
-         self.workspace_buffer = torch.empty(
-             global_config.flashinfer_workspace_size,
-             dtype=torch.uint8,
-             device="cuda",
-         )
-
-         if model_runner.sliding_window_size is None:
-             self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                 self.workspace_buffer, "NHD"
-             )
-             self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                 self.workspace_buffer, "NHD"
-             )
-             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 self.workspace_buffer,
-                 "NHD",
-                 use_tensor_cores=self.decode_use_tensor_cores,
-             )
-         else:
-             # Two wrappers: one for sliding window attention and one for full attention.
-             # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
-             self.prefill_wrapper_ragged = None
-             self.prefill_wrapper_paged = []
-             self.decode_wrapper = []
-             for _ in range(2):
-                 self.prefill_wrapper_paged.append(
-                     BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                 )
-                 self.decode_wrapper.append(
-                     BatchDecodeWithPagedKVCacheWrapper(
-                         self.workspace_buffer,
-                         "NHD",
-                         use_tensor_cores=self.decode_use_tensor_cores,
-                     )
-                 )
-
-         self.forward_metadata = None
-         self.cuda_graph_metadata = {}
-
-     def init_forward_metadata(
-         self, batch: ScheduleBatch, input_metadata: InputMetadata
-     ):
-         if input_metadata.forward_mode.is_decode():
-             prefix_lens = None
-             use_ragged = False
-             total_num_tokens = None
-         else:
-             prefix_lens = input_metadata.extend_prefix_lens
-
-             # Some heuristics to check whether to use ragged forward
-             use_ragged = False
-             if (
-                 torch.sum(input_metadata.seq_lens).item() >= 4096
-                 and self.model_runner.sliding_window_size is None
-             ):
-                 use_ragged = True
-
-             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-
-         update_flashinfer_indices(
-             input_metadata.forward_mode,
-             self.model_runner,
-             input_metadata.req_pool_indices,
-             input_metadata.seq_lens,
-             prefix_lens,
-             use_ragged=use_ragged,
-         )
-
-         self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
-
-     def init_cuda_graph_state(self, max_bs: int):
-         self.cuda_graph_kv_indptr = torch.zeros(
-             (max_bs + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.cuda_graph_kv_indices = torch.zeros(
-             (max_bs * self.model_runner.model_config.context_len,),
-             dtype=torch.int32,
-             device="cuda",
-         )
-         self.cuda_graph_kv_last_page_len = torch.ones(
-             (max_bs,), dtype=torch.int32, device="cuda"
-         )
-
-         if self.model_runner.sliding_window_size is not None:
-             self.cuda_graph_kv_indptr = [
-                 self.cuda_graph_kv_indptr,
-                 self.cuda_graph_kv_indptr.clone(),
-             ]
-             self.cuda_graph_kv_indices = [
-                 self.cuda_graph_kv_indices,
-                 self.cuda_graph_kv_indices.clone(),
-             ]
-
-     def init_forward_metadata_capture_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         if self.model_runner.sliding_window_size is None:
-             decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 self.workspace_buffer,
-                 "NHD",
-                 use_cuda_graph=True,
-                 use_tensor_cores=self.decode_use_tensor_cores,
-                 paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
-                 paged_kv_indices_buffer=self.cuda_graph_kv_indices,
-                 paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
-             )
-         else:
-             decode_wrapper = []
-             for i in range(2):
-                 decode_wrapper.append(
-                     BatchDecodeWithPagedKVCacheWrapper(
-                         self.workspace_buffer,
-                         "NHD",
-                         use_cuda_graph=True,
-                         use_tensor_cores=self.decode_use_tensor_cores,
-                         paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
-                         paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                         paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
-                             :bs
-                         ],
-                     )
-                 )
-
-         update_flashinfer_indices(
-             ForwardMode.DECODE,
-             self.model_runner,
-             req_pool_indices,
-             seq_lens,
-             None,
-             decode_wrapper,
-         )
-
-         self.cuda_graph_metadata[bs] = decode_wrapper
-
-         self.forward_metadata = (False, None, decode_wrapper)
-
-     def init_forward_metadata_replay_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         update_flashinfer_indices(
-             ForwardMode.DECODE,
-             self.model_runner,
-             req_pool_indices[:bs],
-             seq_lens[:bs],
-             None,
-             self.cuda_graph_metadata[bs],
-         )
-
-     def get_cuda_graph_seq_len_fill_value(self):
-         return 0
-
-     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         if not isinstance(self.prefill_wrapper_paged, list):
-             prefill_wrapper_paged = self.prefill_wrapper_paged
-         else:
-             if layer.sliding_window_size != -1:
-                 prefill_wrapper_paged = self.prefill_wrapper_paged[0]
-             else:
-                 prefill_wrapper_paged = self.prefill_wrapper_paged[1]
-
-         use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
-
-         if not use_ragged:
-             if k is not None:
-                 assert v is not None
-                 input_metadata.token_to_kv_pool.set_kv_buffer(
-                     layer.layer_id, input_metadata.out_cache_loc, k, v
-                 )
-             o = prefill_wrapper_paged.forward(
-                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                 input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                 causal=True,
-                 sm_scale=layer.scaling,
-                 window_left=layer.sliding_window_size,
-                 logits_soft_cap=layer.logit_cap,
-             )
-         else:
-             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                 k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
-                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
-                 causal=True,
-                 sm_scale=layer.scaling,
-                 logits_soft_cap=layer.logit_cap,
-             )
-
-             if input_metadata.extend_no_prefix:
-                 o = o1
-             else:
-                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                     input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                     causal=False,
-                     sm_scale=layer.scaling,
-                     logits_soft_cap=layer.logit_cap,
-                 )
-
-                 o, _ = merge_state(o1, s1, o2, s2)
-
-             input_metadata.token_to_kv_pool.set_kv_buffer(
-                 layer.layer_id, input_metadata.out_cache_loc, k, v
-             )
-
-         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
-
-     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
-
-         if isinstance(decode_wrapper, list):
-             if layer.sliding_window_size != -1:
-                 decode_wrapper = decode_wrapper[0]
-             else:
-                 decode_wrapper = decode_wrapper[1]
-
-         if k is not None:
-             assert v is not None
-             input_metadata.token_to_kv_pool.set_kv_buffer(
-                 layer.layer_id, input_metadata.out_cache_loc, k, v
-             )
-
-         o = decode_wrapper.forward(
-             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-             input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-             sm_scale=layer.scaling,
-             logits_soft_cap=layer.logit_cap,
-         )
-
-         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
-
-
- class TritonAttnBackend(AttentionBackend):
-     def __init__(self, model_runner: ModelRunner):
-         # Lazy import to avoid the initialization of cuda context
-         from sglang.srt.layers.triton_attention.decode_attention import (
-             decode_attention_fwd,
-         )
-         from sglang.srt.layers.triton_attention.extend_attention import (
-             extend_attention_fwd,
-         )
-
-         super().__init__()
-
-         self.decode_attention_fwd = decode_attention_fwd
-         self.extend_attention_fwd = extend_attention_fwd
-         self.num_head = (
-             model_runner.model_config.num_attention_heads // model_runner.tp_size
-         )
-
-         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-             self.reduce_dtype = torch.float32
-         else:
-             self.reduce_dtype = torch.float16
-
-         self.forward_metadata = None
-
-         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
-     def init_forward_metadata(
-         self, batch: ScheduleBatch, input_metadata: InputMetadata
-     ):
-         """Init auxiliary variables for triton attention backend."""
-
-         if input_metadata.forward_mode.is_decode():
-             start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
-             start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
-
-             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-             attn_logits = torch.empty(
-                 (self.num_head, total_num_tokens),
-                 dtype=self.reduce_dtype,
-                 device="cuda",
-             )
-
-             max_seq_len = torch.max(input_metadata.seq_lens).item()
-             max_extend_len = None
-         else:
-             start_loc = attn_logits = max_seq_len = None
-             prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
-
-         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
-
-     def init_cuda_graph_state(self, max_bs: int):
-         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-         self.cuda_graph_start_loc = torch.zeros(
-             (max_bs,), dtype=torch.int32, device="cuda"
-         )
-         self.cuda_graph_attn_logits = torch.empty(
-             (
-                 self.num_head,
-                 self.cuda_graph_max_total_num_tokens,
-             ),
-             dtype=self.reduce_dtype,
-             device="cuda",
-         )
-
-     def init_forward_metadata_capture_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         self.forward_metadata = (
-             self.cuda_graph_start_loc,
-             self.cuda_graph_attn_logits,
-             self.cuda_graph_max_seq_len,
-             None,
-         )
-
-     def init_forward_metadata_replay_cuda_graph(
-         self, bs: int, req_pool_indices, seq_lens
-     ):
-         self.cuda_graph_start_loc.zero_()
-         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-     def get_cuda_graph_seq_len_fill_value(self):
-         return 1
-
-     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         # TODO: reuse the buffer across layers
-         if layer.qk_head_dim != layer.v_head_dim:
-             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-         else:
-             o = torch.empty_like(q)
-
-         input_metadata.token_to_kv_pool.set_kv_buffer(
-             layer.layer_id, input_metadata.out_cache_loc, k, v
-         )
-
-         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
-         self.extend_attention_fwd(
-             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-             k.contiguous(),
-             v.contiguous(),
-             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-             input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-             input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-             input_metadata.req_to_token_pool.req_to_token,
-             input_metadata.req_pool_indices,
-             input_metadata.seq_lens,
-             input_metadata.extend_seq_lens,
-             input_metadata.extend_start_loc,
-             max_extend_len,
-             layer.scaling,
-             layer.logit_cap,
-         )
-         return o
-
-     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-         # During torch.compile, there is a bug in rotary_emb that causes the
-         # output value to have a 3D tensor shape. This reshapes the output correctly.
-         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
-
-         # TODO: reuse the buffer across layers
-         if layer.qk_head_dim != layer.v_head_dim:
-             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-         else:
-             o = torch.empty_like(q)
-
-         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
-
-         input_metadata.token_to_kv_pool.set_kv_buffer(
-             layer.layer_id, input_metadata.out_cache_loc, k, v
-         )
-
-         self.decode_attention_fwd(
-             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-             input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-             input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-             input_metadata.req_to_token_pool.req_to_token,
-             input_metadata.req_pool_indices,
-             start_loc,
-             input_metadata.seq_lens,
-             attn_logits,
-             max_seq_len,
-             layer.scaling,
-             layer.logit_cap,
-         )
-         return o
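
The module removed above appears to be superseded in 0.3.3 by the new sglang/srt/layers/attention/ package (entries 12-16 in the file list). Its core contract is the extend/decode routing in AttentionBackend.forward(). A minimal, self-contained sketch of that dispatch pattern, using made-up EchoBackend and FakeMetadata names purely for illustration:

    from dataclasses import dataclass


    @dataclass
    class FakeMetadata:
        is_decode: bool  # stand-in for input_metadata.forward_mode.is_decode()


    class EchoBackend:
        def forward(self, q, k, v, layer, meta: FakeMetadata):
            # Same routing as AttentionBackend.forward() in the removed file.
            if meta.is_decode:
                return self.forward_decode(q, k, v, layer, meta)
            return self.forward_extend(q, k, v, layer, meta)

        def forward_decode(self, q, k, v, layer, meta):
            return "decode"  # one new token per request, KV read from the cache

        def forward_extend(self, q, k, v, layer, meta):
            return "extend"  # prefill with a cached prefix


    backend = EchoBackend()
    assert backend.forward(None, None, None, None, FakeMetadata(True)) == "decode"
    assert backend.forward(None, None, None, None, FakeMetadata(False)) == "extend"
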
sglang/srt/managers/controller_multi.py (deleted)
@@ -1,207 +0,0 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
- """
- A controller that manages multiple data parallel workers.
- Each data parallel worker can manage multiple tensor parallel workers.
- """
-
- import dataclasses
- import logging
- import multiprocessing
- from enum import Enum, auto
-
- import numpy as np
- import zmq
-
- from sglang.srt.managers.controller_single import (
-     start_controller_process as start_controller_process_single,
- )
- from sglang.srt.managers.io_struct import (
-     AbortReq,
-     FlushCacheReq,
-     TokenizedGenerateReqInput,
- )
- from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import configure_logger, kill_parent_process
- from sglang.utils import get_exception_traceback
-
- logger = logging.getLogger(__name__)
-
-
- class LoadBalanceMethod(Enum):
-     """Load balance method."""
-
-     ROUND_ROBIN = auto()
-     SHORTEST_QUEUE = auto()
-
-     @classmethod
-     def from_str(cls, method: str):
-         method = method.upper()
-         try:
-             return cls[method]
-         except KeyError as exc:
-             raise ValueError(f"Invalid load balance method: {method}") from exc
-
-
- @dataclasses.dataclass
- class WorkerHandle:
-     """Store the handle of a data parallel worker."""
-
-     proc: multiprocessing.Process
-     queue: multiprocessing.Queue
-
-
- class ControllerMulti:
-     """A controller that manages multiple data parallel workers."""
-
-     def __init__(
-         self,
-         server_args: ServerArgs,
-         port_args: PortArgs,
-     ):
-         # Parse args
-         self.server_args = server_args
-         self.port_args = port_args
-         self.load_balance_method = LoadBalanceMethod.from_str(
-             server_args.load_balance_method
-         )
-
-         # Init communication
-         context = zmq.Context()
-         self.recv_from_tokenizer = context.socket(zmq.PULL)
-         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
-
-         # Dispatch method
-         self.round_robin_counter = 0
-         dispatch_lookup = {
-             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
-             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
-         }
-         self.dispatching = dispatch_lookup[self.load_balance_method]
-
-         # Start data parallel workers
-         self.workers = []
-         for i in range(server_args.dp_size):
-             self.start_dp_worker(i)
-
-     def start_dp_worker(self, dp_worker_id: int):
-         tp_size = self.server_args.tp_size
-
-         pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
-             duplex=False
-         )
-
-         gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
-         queue = multiprocessing.Queue()
-         proc = multiprocessing.Process(
-             target=start_controller_process_single,
-             args=(
-                 self.server_args,
-                 self.port_args,
-                 pipe_controller_writer,
-                 True,
-                 gpu_ids,
-                 dp_worker_id,
-                 queue,
-             ),
-         )
-         proc.start()
-
-         controller_init_state = pipe_controller_reader.recv()
-         if controller_init_state != "init ok":
-             raise RuntimeError(
-                 f"Initialization failed. controller_init_state: {controller_init_state}"
-             )
-         self.workers.append(
-             WorkerHandle(
-                 proc=proc,
-                 queue=queue,
-             )
-         )
-
-     def round_robin_scheduler(self, input_requests):
-         for r in input_requests:
-             self.workers[self.round_robin_counter].queue.put(r)
-             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                 self.workers
-             )
-
-     def shortest_queue_scheduler(self, input_requests):
-         for r in input_requests:
-             queue_sizes = [worker.queue.qsize() for worker in self.workers]
-             wid = np.argmin(queue_sizes)
-             self.workers[wid].queue.put(r)
-
-     def loop_for_forward(self):
-         while True:
-             recv_reqs = self.recv_requests()
-             self.dispatching(recv_reqs)
-
-     def recv_requests(self):
-         recv_reqs = []
-
-         while True:
-             try:
-                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-             except zmq.ZMQError:
-                 break
-
-             if isinstance(recv_req, FlushCacheReq):
-                 # TODO(lsyin): apply more specific flushCacheReq
-                 for worker in self.workers:
-                     worker.queue.put(recv_req)
-             elif isinstance(recv_req, AbortReq):
-                 in_queue = False
-                 for i, req in enumerate(recv_reqs):
-                     if req.rid == recv_req.rid:
-                         recv_reqs[i] = recv_req
-                         in_queue = True
-                         break
-                 if not in_queue:
-                     # Send abort req to all TP groups
-                     for worker in self.workers:
-                         worker.queue.put(recv_req)
-             elif isinstance(recv_req, TokenizedGenerateReqInput):
-                 recv_reqs.append(recv_req)
-             else:
-                 logger.error(f"Invalid object: {recv_req}")
-
-         return recv_reqs
-
-
- def start_controller_process(
-     server_args: ServerArgs,
-     port_args: PortArgs,
-     pipe_writer,
- ):
-     """Start a controller process."""
-
-     configure_logger(server_args)
-
-     try:
-         controller = ControllerMulti(server_args, port_args)
-     except Exception:
-         pipe_writer.send(get_exception_traceback())
-         raise
-
-     pipe_writer.send("init ok")
-
-     try:
-         controller.loop_for_forward()
-     except Exception:
-         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
-     finally:
-         kill_parent_process()
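
The controller removed above balanced requests across data-parallel workers with either a round-robin or a shortest-queue policy (see LoadBalanceMethod). A standalone sketch of the two policies, with plain Python lists standing in for the multiprocessing queues; the function names here are illustrative and not part of sglang:

    from typing import List


    def round_robin_dispatch(requests: List[object], queues: List[list], cursor: int = 0) -> int:
        # Hand each request to the next worker in a fixed cycle, like round_robin_scheduler().
        for r in requests:
            queues[cursor].append(r)
            cursor = (cursor + 1) % len(queues)
        return cursor  # carry the position over to the next batch of requests


    def shortest_queue_dispatch(requests: List[object], queues: List[list]) -> None:
        # Hand each request to the worker with the fewest pending items,
        # like shortest_queue_scheduler().
        for r in requests:
            target = min(range(len(queues)), key=lambda i: len(queues[i]))
            queues[target].append(r)


    queues = [[], [], []]
    cursor = round_robin_dispatch(["a", "b", "c", "d"], queues)  # queues: [["a", "d"], ["b"], ["c"]]
    shortest_queue_dispatch(["e"], queues)                       # "e" lands on a shortest queue
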