sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py CHANGED

```diff
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 
```
```diff
@@ -22,26 +24,45 @@ BaseTokenToKVPool maps a token location to its KV cache data.
 """
 
 import logging
+import threading
+from enum import IntEnum
+from functools import wraps
 from typing import List, Tuple, Union
 
+import numpy as np
+import psutil
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import debug_timing, get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
+GB = 1024 * 1024 * 1024
+
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records
```
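The new `enable_memory_saver` flag threads through every pool constructor, and allocations are wrapped in `memory_saver_adapter.region()`. The adapter itself lands in the new `sglang/srt/torch_memory_saver_adapter.py` (+59 lines, not shown in this diff); the sketch below is a hypothetical reconstruction of only the API shape these call sites rely on, assuming the optional `torch_memory_saver` package — treat the names as illustrative.

```python
import contextlib

# Hypothetical sketch (not the shipped implementation): create(enable=...)
# returns a pass-through adapter when disabled, so pools can wrap their
# allocations unconditionally.
class TorchMemorySaverAdapter:
    @staticmethod
    def create(enable: bool) -> "TorchMemorySaverAdapter":
        return _RealAdapter() if enable else _NoopAdapter()


class _NoopAdapter(TorchMemorySaverAdapter):
    @contextlib.contextmanager
    def region(self):
        yield  # allocations proceed normally


class _RealAdapter(TorchMemorySaverAdapter):
    @contextlib.contextmanager
    def region(self):
        # Assumed interface of the optional torch_memory_saver package:
        # tensors allocated inside the region can later be paused/resumed
        # to release GPU memory without destroying the Python objects.
        from torch_memory_saver import torch_memory_saver as saver

        with saver.region():
            yield
```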
```diff
@@ -105,8 +126,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
```
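The check now covers both fp8 variants. A small standalone sketch of why the pool stores fp8 as `torch.uint8` (assuming a PyTorch build with float8 support, 2.1 or later): both dtypes are one byte per element, so `.view()` reinterprets the buffer with no copy while indexed writes stay on a supported dtype.

```python
import torch  # requires float8 support (PyTorch >= 2.1)

# index_put (used by `buf[loc] = ...`) is not implemented for float8,
# but a uint8 buffer of the same byte size can be viewed back and forth.
store = torch.empty(4, dtype=torch.uint8)          # the store_dtype buffer
x = torch.tensor([0.5, 1.0, 2.0, 4.0]).to(torch.float8_e5m2)

store[torch.tensor([0, 1, 2, 3])] = x.view(torch.uint8)  # write path
print(store.view(torch.float8_e5m2).to(torch.float32))   # read path
# tensor([0.5000, 1.0000, 2.0000, 4.0000])
```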
```diff
@@ -182,37 +203,80 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
 
+        k_size, v_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+        )
+
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        k_size_bytes = 0
+        for k_cache in self.k_buffer:
+            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        v_size_bytes = 0
+        for v_cache in self.v_buffer:
+            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return k_size_bytes, v_size_bytes
+
+    # Todo: different memory layout
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        flatten = torch.stack(
+            [
+                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
+                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
+            ]
+        )
+        return flatten
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        for i in range(self.layer_num):
+            self.k_buffer[i][indices] = k_data[i]
+            self.v_buffer[i][indices] = v_data[i]
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
```
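`get_kv_size_bytes` is what feeds the new allocation log line, and the arithmetic is easy to check by hand. A worked example with hypothetical Llama-8B-like numbers (32 layers, 8 KV heads, head_dim 128, fp16 — none of these values come from this diff):

```python
import torch

GB = 1024 * 1024 * 1024
size, layer_num, head_num, head_dim = 100_000, 32, 8, 128
itemsize = torch.finfo(torch.float16).bits // 8  # 2 bytes

# One padded slot is added, matching the (size + 1) buffers above.
k_size = (size + 1) * head_num * head_dim * itemsize * layer_num
print(f"K size: {k_size / GB:.2f} GB")  # ~6.10 GB; V is identical, total ~2x
```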
```diff
@@ -232,11 +296,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
```
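`set_kv_buffer` now accepts per-tensor `k_scale`/`v_scale`, dividing on write so an fp8 cache keeps precision; the matching multiply on read lives in the attention backend and is outside this hunk. A standalone round-trip sketch of the same convention:

```python
import torch  # float8 requires PyTorch >= 2.1

# Scale so the largest magnitude maps near the fp8 representable maximum.
k = torch.randn(8, 4) * 3.0
k_scale = (k.abs().max() / torch.finfo(torch.float8_e4m3fn).max).item()

k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)      # write path (this hunk)
k_restored = k_fp8.to(torch.float32) * k_scale     # read path (backend side)
print((k - k_restored).abs().max())                # small quantization error
```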
```diff
@@ -262,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
```
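The MLA pool caches a single latent row of width `kv_lora_rank + qk_rope_head_dim` per token rather than per-head K and V. A back-of-envelope comparison, using DeepSeek-V2-like numbers chosen purely for illustration:

```python
# Per-token cache width: shared latent row (MLA) vs. full K+V (MHA).
kv_lora_rank, qk_rope_head_dim = 512, 64
head_num, head_dim = 128, 128

mla_elems = kv_lora_rank + qk_rope_head_dim  # one latent row per token
mha_elems = 2 * head_num * head_dim          # K and V across all heads

print(mla_elems, mha_elems, f"{mha_elems / mla_elems:.1f}x")  # 576 32768 56.9x
```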
```diff
@@ -315,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
```
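For the double-sparse pool, the `label_buffer` caches each token's key values on a fixed set of "heavy" channels, so decode can score cached tokens cheaply on that channel subset before running full attention. A hedged sketch of what one label holds; the channel-selection heuristic below is illustrative, not sglang's:

```python
import torch

head_num, head_dim, heavy_channel_num = 8, 128, 16

k = torch.randn(head_num, head_dim)  # one token's keys, all heads
# Illustrative heuristic: pick the channels with the largest mean magnitude.
heavy = torch.topk(k.abs().mean(dim=0), heavy_channel_num).indices
cache_label = k[:, heavy]            # shape [head_num, heavy_channel_num]
print(cache_label.shape)
```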
```diff
@@ -361,3 +440,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class MLATokenToKVPoolHost:
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 2.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        assert (
+            host_to_device_ratio >= 1
+        ), "The host memory should be larger than the device memory with the current protocol"
+        # todo, other ways of configuring the size
+
+        self.device_pool = device_pool
+        self.host_to_device_ratio = host_to_device_ratio
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.size = int(device_pool.size * host_to_device_ratio)
+        self.dtype = device_pool.store_dtype
+        self.head_num = device_pool.head_num
+        self.head_dim = device_pool.head_dim
+        self.layer_num = device_pool.layer_num
+        self.size_per_token = (
+            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+        )
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        self.can_use_mem_size = self.size
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    @synchronized
+    def clear(self):
+        self.mem_state.fill_(0)
+        self.can_use_mem_size = self.size
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+    @synchronized
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.can_use_mem_size:
+            return None
+
+        # todo: de-fragementation
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        self.mem_state[select_index] = MemoryStateInt.RESERVED
+        self.can_use_mem_size -= need_size
+
+        return select_index
+
+    @synchronized
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_backup(self, indices: torch.Tensor):
+        assert self.is_synced(indices), (
+            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def protect_write(self, indices: torch.Tensor):
+        assert self.is_reserved(indices), (
+            f"The host memory slots should be RESERVED before write operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def protect_load(self, indices: torch.Tensor):
+        assert self.is_backup(indices), (
+            f"The host memory slots should be in BACKUP state before load operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def complete_io(self, indices: torch.Tensor):
+        assert self.is_protected(indices), (
+            f"The host memory slots should be PROTECTED during I/O operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized
+    def free(self, indices: torch.Tensor) -> int:
+        self.mem_state[indices] = MemoryStateInt.IDLE
+        self.free_slots = torch.concat([self.free_slots, indices])
+        self.can_use_mem_size += len(indices)
+        return len(indices)
```
sglang/srt/metrics/collector.py CHANGED

```diff
@@ -109,31 +109,31 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.num_requests_total = Counter(
+            name="sglang:num_requests_total",
+            documentation="Number of requests processed.",
+            labelnames=labels.keys(),
+        )
+
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
```
```diff
@@ -168,21 +168,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
+                0.1,
+                0.25,
                 0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
```
```diff
@@ -193,11 +191,10 @@ class TokenizerMetricsCollector:
         # Convenience function for logging to counter.
         counter.labels(**self.labels).inc(data)
 
-    def inc_prompt_tokens(self, value: int):
-        self._log_counter(self.prompt_tokens_total, value)
-
-    def inc_generation_tokens(self, value: int):
-        self._log_counter(self.generation_tokens_total, value)
+    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.num_requests_total.labels(**self.labels).inc(1)
 
     def observe_time_to_first_token(self, value: Union[float, int]):
         self._log_histogram(self.histogram_time_to_first_token, value)
```
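The collector gains a `sglang:num_requests_total` counter, coarser longer-tailed latency buckets, and a single `observe_one_finished_request` hook that updates all per-request counters atomically. A minimal self-contained sketch of the same `prometheus_client` pattern (metric names carry a `_demo` suffix so they don't collide with a live server):

```python
from prometheus_client import Counter, Histogram

labels = {"model_name": "demo-model"}

num_requests_total = Counter(
    name="sglang_demo_num_requests",
    documentation="Number of requests processed.",
    labelnames=labels.keys(),
)
e2e_latency = Histogram(
    name="sglang_demo_e2e_request_latency_seconds",
    documentation="Histogram of end-to-end request latency in seconds",
    labelnames=labels.keys(),
    buckets=[0.1, 0.25, 0.5, 1, 2, 5, 10, 20, 40, 60, 80, 120, 160],
)

# One finished request: bump the counter, record its latency.
num_requests_total.labels(**labels).inc(1)
e2e_latency.labels(**labels).observe(3.2)
```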
sglang/srt/model_executor/cuda_graph_runner.py CHANGED

```diff
@@ -124,10 +124,12 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
-        if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 33)) + [64, 128]
-        else:
-            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
         if max(self.capture_bs) > model_runner.req_to_token_pool.size:
             # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
```
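The new `--cuda-graph-bs` server argument (see the `server_args.py` entry in the file list) lets users override the captured batch sizes. When padding is enabled, the runner replays the smallest captured graph that fits the live batch and pads up to it; a sketch of that selection logic, treating the bisect lookup as illustrative:

```python
import bisect

capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 8, 16, ..., 160

def padded_bs(raw_bs: int) -> int:
    # Smallest captured size >= the runtime batch size.
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]

print(padded_bs(3), padded_bs(33))  # 4 40
```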
```diff
@@ -322,6 +324,8 @@ class CudaGraphRunner:
         global_num_tokens = None
         gathered_buffer = None
 
+        spec_info = self.get_spec_info(num_tokens, positions)
+
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
             batch_size=bs,
```
```diff
@@ -338,10 +342,13 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
-            mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
-            spec_info=self.get_spec_info(num_tokens, positions),
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
         )
 
         # Attention backend
```
```diff
@@ -446,10 +453,10 @@ class CudaGraphRunner:
 
         if self.model_runner.is_draft_worker:
             spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.model_runner.server_args)
             spec_info.hidden_states = self.hidden_states[:num_tokens]
             spec_info.positions = positions
             spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
-            spec_info.init(self.model_runner.server_args)
         else:
             spec_info = EagleVerifyInput(
                 None,
```
sglang/srt/model_executor/forward_batch_info.py CHANGED

```diff
@@ -106,6 +106,24 @@ class ForwardMode(IntEnum):
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
 
 @dataclass
 class ForwardBatch:
```
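`CaptureHiddenMode` moves here from the bottom of the file (see the last hunk below), presumably because the new dataclass field `capture_hidden_mode: CaptureHiddenMode` in the next hunk evaluates its annotation when the class body runs, so the enum must already be defined. A minimal sketch of the enum in use:

```python
from enum import IntEnum, auto

# Stand-in copy for illustration; the real enum lives in
# forward_batch_info.py as shown in the hunk above.
class CaptureHiddenMode(IntEnum):
    NULL = auto()  # don't keep hidden states
    FULL = auto()  # keep hidden states for every position (EAGLE draft capture)
    LAST = auto()  # keep only the final position's hidden state

    def need_capture(self) -> bool:
        return self != CaptureHiddenMode.NULL

print(CaptureHiddenMode.FULL.need_capture())  # True
```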
```diff
@@ -174,6 +192,7 @@ class ForwardBatch:
     # Speculative decoding
     spec_info: SpecInfo = None
     spec_algorithm: SpeculativeAlgorithm = None
+    capture_hidden_mode: CaptureHiddenMode = None
 
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
```
```diff
@@ -265,6 +284,7 @@ class ForwardBatch:
             sampling_info=batch.sampling_info,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
 
```
```diff
@@ -400,18 +420,3 @@ def compute_position_torch(
 @maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
-class CaptureHiddenMode(IntEnum):
-    NULL = auto()
-    FULL = auto()
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST
```