sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py

@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, set_random_seed
+from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

 logger = logging.getLogger(__name__)

@@ -45,13 +45,18 @@ class TpModelWorker:
         tp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.tp_rank = tp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path,
+            (
+                server_args.model_path
+                if not is_draft_worker
+                else server_args.speculative_draft_model_path
+            ),
             trust_remote_code=server_args.trust_remote_code,
             revision=server_args.revision,
             context_length=server_args.context_length,
@@ -68,6 +73,7 @@ class TpModelWorker:
             tp_size=server_args.tp_size,
             nccl_port=nccl_port,
             server_args=server_args,
+            is_draft_worker=is_draft_worker,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
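The new is_draft_worker flag makes TpModelWorker load server_args.speculative_draft_model_path instead of server_args.model_path, so the same worker class can act as the draft model in speculative decoding. A hedged sketch of constructing such a worker; only the is_draft_worker keyword and the speculative_draft_model_path field come from this diff, and the remaining constructor arguments are assumed to keep their existing names:

# Hypothetical sketch: a draft-model worker for speculative (EAGLE) decoding.
draft_worker = TpModelWorker(
    server_args=server_args,   # assumed to carry speculative_draft_model_path
    gpu_id=gpu_id,
    tp_rank=tp_rank,
    dp_rank=None,
    nccl_port=nccl_port,
    is_draft_worker=True,      # routes model loading to the draft model path
                               # (the flag is also forwarded to ModelRunner)
)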
@@ -150,12 +156,18 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
+        skip_sample: bool = False,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
             launch_done.set()
-        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
+        if skip_sample:
+            next_token_ids = None
+        else:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
         return logits_output, next_token_ids

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
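The generation path can now skip sampling entirely: with skip_sample=True it runs the forward pass, returns the logits, and leaves next_token_ids as None. A hedged sketch of the calling pattern; the worker and batch objects are assumed to already exist:

# Sketch: run the forward pass only and defer (or replace) token selection.
logits_output, next_token_ids = worker.forward_batch_generation(
    model_worker_batch,
    skip_sample=True,
)
assert next_token_ids is None
# The caller, e.g. a speculative verify step, can now pick tokens itself
# instead of relying on model_runner.sample().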
@@ -191,7 +203,7 @@ class TpModelWorker:

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-            recv_req.name, recv_req.tensor
+            MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
         )
         return success, message
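update_weights_from_tensor now receives a single serialized blob of named tensors (recv_req.serialized_named_tensors) and unpacks it with the newly imported MultiprocessingSerializer, instead of a separate name and tensor. A hedged sketch of what the sending side might look like; the serialize() call and the request construction are assumptions mirroring the deserialize() call shown above:

# Hypothetical sketch of the producer side of the new request format.
named_tensors = [("model.embed_tokens.weight", new_weight)]   # illustrative name/tensor
recv_req = UpdateWeightsFromTensorReqInput(
    serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
)
success, message = worker.update_weights_from_tensor(recv_req)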
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -144,10 +144,9 @@ class TpModelWorkerClient:

             # Copy results to the CPU
             if model_worker_batch.return_logprob:
-                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].to("cpu", non_blocking=True)
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+                )
                 if logits_output.input_token_logprobs is not None:
                     logits_output.input_token_logprobs = (
                         logits_output.input_token_logprobs.to("cpu", non_blocking=True)
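The overlap-thread client previously gathered one logprob per sampled token on the GPU (an advanced index into the [batch, vocab] tensor) before copying; it now copies next_token_logprobs unchanged and leaves any per-token selection to later code. A small self-contained illustration of the indexing pattern that was removed, with made-up shapes:

import torch

batch, vocab = 4, 32000
next_token_logprobs = torch.randn(batch, vocab)      # illustrative [batch, vocab] tensor
next_token_ids = torch.randint(0, vocab, (batch,))   # illustrative sampled ids

# Old pattern (removed): gather one value per row on the device, then copy.
gathered = next_token_logprobs[torch.arange(batch), next_token_ids]   # shape [batch]

# New pattern: copy the whole tensor to the CPU without gathering.
on_cpu = next_token_logprobs.to("cpu", non_blocking=True)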
sglang/srt/mem_cache/memory_pool.py

@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
 """

 import logging
+import threading
+from enum import IntEnum
+from functools import wraps
 from typing import List, Tuple, Union

+import psutil
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import debug_timing, get_compiler_backend

 logger = logging.getLogger(__name__)

@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         del self.k_buffer
         del self.v_buffer

+    # Todo: different memory layout
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        flatten = torch.stack(
+            [
+                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
+                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
+            ]
+        )
+        return flatten
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        for i in range(self.layer_num):
+            self.k_buffer[i][indices] = k_data[i]
+            self.v_buffer[i][indices] = v_data[i]
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
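get_flat_data packs the K and V entries of the selected token slots into a single contiguous tensor of shape [2, layer_num, len(indices), head_num, head_dim], and transfer writes such a chunk back into the per-layer buffers. Together with the host pool added further down, this is the copy path of the new hierarchical KV cache. A hedged usage sketch; the pool objects and index tensors are assumed to exist and be consistent with each other:

# Sketch: mirror a set of device KV slots into host memory, then restore them.
chunk = device_pool.get_flat_data(device_indices)    # contiguous [2, L, n, H, D] tensor
host_pool.transfer(host_indices, chunk)              # device -> host copy

restored = host_pool.get_flat_data(host_indices)
device_pool.transfer(device_indices, restored)       # host -> device copy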
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class MLATokenToKVPoolHost:
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 2.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        assert (
+            host_to_device_ratio >= 1
+        ), "The host memory should be larger than the device memory with the current protocol"
+        # todo, other ways of configuring the size
+
+        self.device_pool = device_pool
+        self.host_to_device_ratio = host_to_device_ratio
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.size = int(device_pool.size * host_to_device_ratio)
+        self.dtype = device_pool.store_dtype
+        self.head_num = device_pool.head_num
+        self.head_dim = device_pool.head_dim
+        self.layer_num = device_pool.layer_num
+        self.size_per_token = (
+            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+        )
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        self.can_use_mem_size = self.size
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    @synchronized
+    def clear(self):
+        self.mem_state.fill_(0)
+        self.can_use_mem_size = self.size
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+    @synchronized
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.can_use_mem_size:
+            return None
+
+        # todo: de-fragementation
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        self.mem_state[select_index] = MemoryStateInt.RESERVED
+        self.can_use_mem_size -= need_size
+
+        return select_index
+
+    @synchronized
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_backup(self, indices: torch.Tensor):
+        assert self.is_synced(indices), (
+            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def protect_write(self, indices: torch.Tensor):
+        assert self.is_reserved(indices), (
+            f"The host memory slots should be RESERVED before write operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def protect_load(self, indices: torch.Tensor):
+        assert self.is_backup(indices), (
+            f"The host memory slots should be in BACKUP state before load operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def complete_io(self, indices: torch.Tensor):
+        assert self.is_protected(indices), (
+            f"The host memory slots should be PROTECTED during I/O operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized
+    def free(self, indices: torch.Tensor) -> int:
+        self.mem_state[indices] = MemoryStateInt.IDLE
+        self.free_slots = torch.concat([self.free_slots, indices])
+        self.can_use_mem_size += len(indices)
+        return len(indices)
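MLATokenToKVPoolHost manages host-side KV slots with a small per-slot state machine (IDLE, RESERVED, PROTECTED, SYNCED, BACKUP); the @synchronized decorator serializes every transition under the pool's RLock. A hedged sketch of one backup/restore cycle using only the methods added above; the actual orchestration lives in the new cache_controller.py and may differ in detail:

# Sketch: back up device KV data to the host, later load it back.
host_indices = host_pool.alloc(len(device_indices))          # IDLE -> RESERVED
if host_indices is not None:
    host_pool.protect_write(host_indices)                    # RESERVED -> PROTECTED
    host_pool.transfer(host_indices, device_pool.get_flat_data(device_indices))
    host_pool.complete_io(host_indices)                      # PROTECTED -> SYNCED
    host_pool.update_backup(host_indices)                    # SYNCED -> BACKUP

    # ... later, restore onto the device ...
    host_pool.protect_load(host_indices)                     # BACKUP -> PROTECTED
    device_pool.transfer(device_indices, host_pool.get_flat_data(host_indices))
    host_pool.complete_io(host_indices)                      # PROTECTED -> SYNCED
    host_pool.free(host_indices)                             # slots return to IDLE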
sglang/srt/metrics/collector.py

@@ -114,26 +114,20 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )

@@ -168,21 +162,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
+                0.1,
+                0.25,
                 0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
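Both latency histograms drop their fine sub-second buckets and extend the upper range, settling on a nearly identical ladder from 0.1 s to 160 s. For reference, this is how a histogram with the new end-to-end buckets is declared with prometheus_client; the metric and label names here are illustrative, not the collector's actual ones:

from prometheus_client import Histogram

e2e_latency_seconds = Histogram(
    "example_e2e_request_latency_seconds",
    "Histogram of end-to-end request latency in seconds",
    labelnames=["model_name"],
    buckets=[0.1, 0.25, 0.5, 1, 2, 5, 10, 20, 40, 60, 80, 120, 160],
)

# Prometheus buckets are cumulative: an observed value is counted in every
# bucket whose upper bound ("le") is greater than or equal to the value.
e2e_latency_seconds.labels(model_name="demo").observe(3.7)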