sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 
@@ -22,26 +24,45 @@ BaseTokenToKVPool maps a token location to its KV cache data.
 """
 
 import logging
+import threading
+from enum import IntEnum
+from functools import wraps
 from typing import List, Tuple, Union
 
+import numpy as np
+import psutil
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import debug_timing, get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
+GB = 1024 * 1024 * 1024
+
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records
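Every large allocation in this file is now routed through `TorchMemorySaverAdapter.create(enable=...)` and its `region()` context manager; the adapter itself lives in the new `sglang/srt/torch_memory_saver_adapter.py` (see the file list above). A minimal sketch of the general pattern, assuming a factory that falls back to a no-op context manager when the feature is disabled (the real adapter may differ in detail):

# Illustrative sketch only, not the sglang implementation.
from abc import ABC
from contextlib import contextmanager, nullcontext


class TorchMemorySaverAdapter(ABC):
    @staticmethod
    def create(enable: bool):
        return _RealAdapter() if enable else _NoopAdapter()


class _RealAdapter(TorchMemorySaverAdapter):
    @contextmanager
    def region(self):
        # Assumption: tensors allocated inside this region are registered
        # with a memory saver so their GPU memory can later be released
        # and restored without rebuilding the pool objects.
        yield


class _NoopAdapter(TorchMemorySaverAdapter):
    def region(self):
        # When the feature is off, region() is just a do-nothing context.
        return nullcontext()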
@@ -105,8 +126,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
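Accepting `torch.float8_e4m3fn` alongside `torch.float8_e5m2` widens the FP8 KV-cache path. Both are 1-byte dtypes, so the same `torch.uint8` storage trick applies: the pool writes through a `uint8` view (where indexed assignment is supported) and reinterprets on read with `.view(dtype)`. A standalone sketch of that round trip, separate from the pool code:

import torch

dtype = torch.float8_e5m2                      # or torch.float8_e4m3fn
store = torch.zeros(8, dtype=torch.uint8)      # what the pool actually holds

x = torch.randn(3).to(dtype)                   # incoming cache values
idx = torch.tensor([1, 4, 6])
store[idx] = x.view(torch.uint8)               # indexed write works on uint8

restored = store.view(dtype)                   # zero-copy reinterpretation on read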
@@ -182,37 +203,80 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
 
+        k_size, v_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+        )
+
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        k_size_bytes = 0
+        for k_cache in self.k_buffer:
+            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        v_size_bytes = 0
+        for v_cache in self.v_buffer:
+            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return k_size_bytes, v_size_bytes
+
+    # Todo: different memory layout
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        flatten = torch.stack(
+            [
+                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
+                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
+            ]
+        )
+        return flatten
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        for i in range(self.layer_num):
+            self.k_buffer[i][indices] = k_data[i]
+            self.v_buffer[i][indices] = v_data[i]
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
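The new `get_kv_size_bytes()` just sums `numel * itemsize` over the per-layer K and V buffers, so the logged "KV Cache is allocated" sizes are easy to sanity-check by hand. A quick back-of-the-envelope check with made-up example values (not sglang defaults):

GB = 1024 * 1024 * 1024

# hypothetical config: 1M token slots, 32 layers, 8 KV heads, head_dim 128, fp16
size, layer_num, head_num, head_dim, itemsize = 1_000_000, 32, 8, 128, 2

k_size = (size + 1) * head_num * head_dim * itemsize * layer_num  # one slot per token, per layer
v_size = k_size                                                   # V buffers mirror K exactly

print(f"K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB")  # ~61.04 GB each

Note also that `get_flat_data` stacks K over V into a tensor of shape (2, layer_num, len(indices), head_num, head_dim), which matches the layout of the `MLATokenToKVPoolHost.kv_buffer` introduced further down, so host/device copies are a single contiguous transfer.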
@@ -232,11 +296,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
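The optional `k_scale`/`v_scale` arguments divide the incoming K/V values before the cast, the usual way to fit a wider dynamic range into FP8: the cache stores `x / scale` and the consumer multiplies the scale back in. A small standalone illustration (not sglang code; the scale value here is an arbitrary assumption):

import torch

x = torch.tensor([123.0, -0.75, 880.0])     # example values; fp8 e4m3 tops out near ±448
scale = 4.0                                  # per-tensor scale, assumed chosen elsewhere

q = (x / scale).to(torch.float8_e4m3fn)      # store scaled values, as set_kv_buffer now does
deq = q.to(torch.float32) * scale            # consumer multiplies the scale back in

print(deq)  # close to the original; 880.0 alone exceeds the ~±448 fp8 e4m3 range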
@@ -262,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
@@ -315,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
@@ -361,3 +440,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class MLATokenToKVPoolHost:
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 2.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        assert (
+            host_to_device_ratio >= 1
+        ), "The host memory should be larger than the device memory with the current protocol"
+        # todo, other ways of configuring the size
+
+        self.device_pool = device_pool
+        self.host_to_device_ratio = host_to_device_ratio
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.size = int(device_pool.size * host_to_device_ratio)
+        self.dtype = device_pool.store_dtype
+        self.head_num = device_pool.head_num
+        self.head_dim = device_pool.head_dim
+        self.layer_num = device_pool.layer_num
+        self.size_per_token = (
+            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+        )
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        self.can_use_mem_size = self.size
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    @synchronized
+    def clear(self):
+        self.mem_state.fill_(0)
+        self.can_use_mem_size = self.size
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+    @synchronized
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.can_use_mem_size:
+            return None
+
+        # todo: de-fragementation
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        self.mem_state[select_index] = MemoryStateInt.RESERVED
+        self.can_use_mem_size -= need_size
+
+        return select_index
+
+    @synchronized
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_backup(self, indices: torch.Tensor):
+        assert self.is_synced(indices), (
+            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def protect_write(self, indices: torch.Tensor):
+        assert self.is_reserved(indices), (
+            f"The host memory slots should be RESERVED before write operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def protect_load(self, indices: torch.Tensor):
+        assert self.is_backup(indices), (
+            f"The host memory slots should be in BACKUP state before load operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def complete_io(self, indices: torch.Tensor):
+        assert self.is_protected(indices), (
+            f"The host memory slots should be PROTECTED during I/O operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized
+    def free(self, indices: torch.Tensor) -> int:
+        self.mem_state[indices] = MemoryStateInt.IDLE
+        self.free_slots = torch.concat([self.free_slots, indices])
+        self.can_use_mem_size += len(indices)
+        return len(indices)
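The asserts above encode a small per-slot state machine for the hierarchical cache: alloc() marks slots RESERVED, protect_write() takes RESERVED to PROTECTED, complete_io() takes PROTECTED to SYNCED, update_backup() takes SYNCED to BACKUP, protect_load() takes BACKUP back to PROTECTED, and free() returns slots to IDLE; the driving logic lives in the new `sglang/srt/managers/cache_controller.py`. The `synchronized` decorator simply wraps each transition in `with self.lock:`. A self-contained toy (names invented for the example) showing the same decorator guarding a tiny IDLE/RESERVED free-list:

import threading
from enum import IntEnum
from functools import wraps


def synchronized(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:                    # serialize all state transitions
            return func(self, *args, **kwargs)

    return wrapper


class State(IntEnum):
    IDLE = 0
    RESERVED = 1


class ToySlotPool:
    def __init__(self, size: int):
        self.lock = threading.RLock()
        self.state = [State.IDLE] * size
        self.free_slots = list(range(size))

    @synchronized
    def alloc(self, n: int):
        picked, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        for i in picked:
            self.state[i] = State.RESERVED
        return picked

    @synchronized
    def free(self, slots):
        for i in slots:
            self.state[i] = State.IDLE
        self.free_slots.extend(slots)


pool = ToySlotPool(8)
slots = pool.alloc(3)   # safe to call from multiple threads
pool.free(slots)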
sglang/srt/metrics/collector.py

@@ -109,31 +109,31 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.num_requests_total = Counter(
+            name="sglang:num_requests_total",
+            documentation="Number of requests processed.",
+            labelnames=labels.keys(),
+        )
+
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
@@ -168,21 +168,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
+                0.1,
+                0.25,
                 0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
@@ -193,11 +191,10 @@ class TokenizerMetricsCollector:
         # Convenience function for logging to counter.
         counter.labels(**self.labels).inc(data)
 
-    def inc_prompt_tokens(self, value: int):
-        self._log_counter(self.prompt_tokens_total, value)
-
-    def inc_generation_tokens(self, value: int):
-        self._log_counter(self.generation_tokens_total, value)
+    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.num_requests_total.labels(**self.labels).inc(1)
 
     def observe_time_to_first_token(self, value: Union[float, int]):
         self._log_histogram(self.histogram_time_to_first_token, value)
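`observe_one_finished_request` replaces the separate `inc_prompt_tokens`/`inc_generation_tokens` helpers and also bumps the new `sglang:num_requests_total` counter, so all three counters advance together exactly once per finished request. A minimal standalone `prometheus_client` sketch of the same labelled-counter pattern (metric and label names below are placeholders, not sglang's):

from prometheus_client import Counter, generate_latest

labels = {"model_name": "example-model"}   # placeholder label set

requests_total = Counter(
    name="example_num_requests_total",
    documentation="Number of requests processed.",
    labelnames=labels.keys(),
)

# One finished request: increment once with the bound label values.
requests_total.labels(**labels).inc(1)

print(generate_latest().decode())   # exposition-format text, as scraped by Prometheus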
sglang/srt/model_executor/cuda_graph_runner.py

@@ -124,10 +124,12 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
-        if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 33)) + [64, 128]
-        else:
-            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
         if max(self.capture_bs) > model_runner.req_to_token_pool.size:
             # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
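The capture list can now be pinned explicitly through the new `cuda_graph_bs` server argument (added in the `sglang/srt/server_args.py` change listed above), with the old padding-dependent heuristics kept as the fallback when it is unset. A hedged usage sketch, assuming `sgl.Runtime` forwards keyword arguments to `ServerArgs` as it does for other options:

import sglang as sgl

# Assumption: cuda_graph_bs restricts CUDA graph capture to exactly these batch sizes.
runtime = sgl.Runtime(
    model_path="meta-llama/Llama-3.1-8B-Instruct",   # example model, substitute your own
    cuda_graph_bs=[1, 2, 4, 8],
)
runtime.shutdown()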
@@ -322,6 +324,8 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
+        spec_info = self.get_spec_info(num_tokens, positions)
+
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
             batch_size=bs,
@@ -338,10 +342,13 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
-            mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
-            spec_info=self.get_spec_info(num_tokens, positions),
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
         )
 
         # Attention backend
@@ -446,10 +453,10 @@ class CudaGraphRunner:
 
             if self.model_runner.is_draft_worker:
                 spec_info = EAGLEDraftInput()
+                spec_info.load_server_args(self.model_runner.server_args)
                 spec_info.hidden_states = self.hidden_states[:num_tokens]
                 spec_info.positions = positions
                 spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
-                spec_info.init(self.model_runner.server_args)
             else:
                 spec_info = EagleVerifyInput(
                     None,
sglang/srt/model_executor/forward_batch_info.py

@@ -106,6 +106,24 @@ class ForwardMode(IntEnum):
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
 
 @dataclass
 class ForwardBatch:
110
128
  @dataclass
111
129
  class ForwardBatch:
@@ -174,6 +192,7 @@ class ForwardBatch:
174
192
  # Speculative decoding
175
193
  spec_info: SpecInfo = None
176
194
  spec_algorithm: SpeculativeAlgorithm = None
195
+ capture_hidden_mode: CaptureHiddenMode = None
177
196
 
178
197
  # For Qwen2-VL
179
198
  mrope_positions: torch.Tensor = None
@@ -265,6 +284,7 @@ class ForwardBatch:
             sampling_info=batch.sampling_info,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
 
@@ -400,18 +420,3 @@ def compute_position_torch(
 @maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
-class CaptureHiddenMode(IntEnum):
-    NULL = auto()
-    FULL = auto()
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST