sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,7 +53,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         model_port_args: ModelPortArgs,
-        model_overide_args,
+        model_overide_args: dict,
     ):
         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            4096
+            16384
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -178,7 +178,7 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
+        if not isinstance(recv_reqs, list):
             recv_reqs = obtain(recv_reqs)

         try:
@@ -206,11 +206,11 @@ class ModelTpServer:

     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_fill_batch()
+        new_batch = self.get_new_prefill_batch()

         if new_batch is not None:
-            # Run a new fill batch
-            self.forward_fill_batch(new_batch)
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -219,33 +219,32 @@ class ModelTpServer:
                 else:
                     self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
-                for _ in range(10):
+                for _ in range(global_config.num_continue_decode_steps):
                     self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)

                     # Print stats
-                    if self.tp_rank == 0:
-                        if self.decode_forward_ct % 40 == 0:
-                            num_used = self.max_total_num_tokens - (
-                                self.token_to_kv_pool.available_size()
-                                + self.tree_cache.evictable_size()
-                            )
-                            throughput = self.num_generated_tokens / (
-                                time.time() - self.last_stats_tic
-                            )
-                            self.num_generated_tokens = 0
-                            self.last_stats_tic = time.time()
-                            logger.info(
-                                f"[gpu_id={self.gpu_id}] Decode batch. "
-                                f"#running-req: {len(self.running_batch.reqs)}, "
-                                f"#token: {num_used}, "
-                                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                                f"gen throughput (token/s): {throughput:.2f}, "
-                                f"#queue-req: {len(self.forward_queue)}"
-                            )
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        num_used = self.max_total_num_tokens - (
+                            self.token_to_kv_pool.available_size()
+                            + self.tree_cache.evictable_size()
+                        )
+                        throughput = self.num_generated_tokens / (
+                            time.time() - self.last_stats_tic
+                        )
+                        self.num_generated_tokens = 0
+                        self.last_stats_tic = time.time()
+                        logger.info(
+                            f"[gpu_id={self.gpu_id}] Decode batch. "
+                            f"#running-req: {len(self.running_batch.reqs)}, "
+                            f"#token: {num_used}, "
+                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                            f"gen throughput (token/s): {throughput:.2f}, "
+                            f"#queue-req: {len(self.forward_queue)}"
+                        )

                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -313,12 +312,12 @@ class ModelTpServer:
             )
         self.forward_queue.append(req)

-    def get_new_fill_batch(self) -> Optional[Batch]:
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_running_requests
-        ):
-            return None
+    def get_new_prefill_batch(self) -> Optional[Batch]:
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
+        if running_bs >= self.max_running_requests:
+            return

         # Compute matched prefix length
         for req in self.forward_queue:
@@ -344,7 +343,7 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -358,7 +357,7 @@ class ModelTpServer:
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                 # Need at least one token to compute logits
                 req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
@@ -366,7 +365,7 @@ class ModelTpServer:
                     req.image_offset += 1

             if (
-                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -378,7 +377,7 @@ class ModelTpServer:
                 available_size += delta

                 if not (
-                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -389,19 +388,20 @@ class ModelTpServer:
                     # Add this request to the running batch
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.extend_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.sampling_params.max_new_tokens
                     )
                     new_batch_input_tokens += req.extend_input_len
             else:
                 break
+
+            if running_bs + len(can_run_list) >= self.max_running_requests:
+                break
+
         if len(can_run_list) == 0:
             return None

         # Print stats
         if self.tp_rank == 0:
-            running_req = (
-                0 if self.running_batch is None else len(self.running_batch.reqs)
-            )
             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
                 hit_tokens + new_batch_input_tokens
@@ -416,7 +416,7 @@ class ModelTpServer:
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                f"#running-req: {running_req}, "
+                f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
             # logger.debug(
@@ -436,7 +436,7 @@ class ModelTpServer:
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

-    def forward_fill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -746,8 +746,8 @@ class ModelTpClient:
             # Init model
             assert len(gpu_ids) == 1
             self.model_server = ModelTpService().exposed_ModelTpServer(
-                0,
                 gpu_ids[0],
+                0,
                 server_args,
                 model_port_args,
                 model_overide_args,
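
Taken together, the ModelTpServer changes above rename the "fill" path to "prefill", raise the default max_prefill_tokens from 4096 to 16384, budget each request by sampling_params.max_new_tokens instead of the max_new_tokens() helper, and cap admissions through the new running_bs check. The sketch below illustrates that admission loop in isolation; it is a hedged simplification, not sglang's code: the Req dataclass, the function name pick_prefill_batch, and the max_running_requests value are illustrative assumptions (only the 16384 prefill default comes from the diff), and prefix caching and tree-cache locking are omitted.

# Simplified sketch of the prefill admission logic after this change.
from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    extend_input_len: int  # uncached prompt tokens that still need prefill
    max_new_tokens: int    # decode budget requested via sampling params


def pick_prefill_batch(
    queue: List[Req],
    running_bs: int,
    available_tokens: int,
    max_prefill_tokens: int = 16384,  # new default in 0.1.21 (was 4096)
    max_running_requests: int = 64,   # illustrative cap
) -> List[Req]:
    if running_bs >= max_running_requests:
        return []

    can_run, total_tokens, input_tokens = [], 0, 0
    for req in queue:
        # KV budget: prompt plus the full decode budget must still fit.
        fits_kv = (
            req.extend_input_len + req.max_new_tokens + total_tokens
            < available_tokens
        )
        # Prefill budget: bounded per batch, but always admit at least one request.
        fits_prefill = (
            req.extend_input_len + input_tokens <= max_prefill_tokens
            or len(can_run) == 0
        )
        if not (fits_kv and fits_prefill):
            break
        can_run.append(req)
        total_tokens += req.extend_input_len + req.max_new_tokens
        input_tokens += req.extend_input_len
        # New in 0.1.21: stop once the running-request cap would be reached.
        if running_bs + len(can_run) >= max_running_requests:
            break
    return can_run


batch = pick_prefill_batch(
    [Req(1000, 128), Req(2000, 256), Req(30000, 64)],
    running_bs=2,
    available_tokens=20000,
)
print(len(batch))  # 2: the third request would overflow the KV budget

Running the example admits the first two requests and stops at the third, whose prompt plus decode budget no longer fits; this mirrors the early-break behavior the diff adds to get_new_prefill_batch().
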
sglang/srt/memory_pool.py CHANGED
@@ -8,96 +8,98 @@ logger = logging.getLogger(__name__)


 class ReqToTokenPool:
-    def __init__(self, size, max_context_len):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(self, size: int, max_context_len: int):
         self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
-        self.can_use_mem_size = size
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
+        self.can_use_mem_size = size

-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
         if need_size > self.can_use_mem_size:
             return None

-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
-        self.mem_state[select_index] = 0
+        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
+        self.mem_state[select_index] = False
         self.can_use_mem_size -= need_size
-        return select_index.to(torch.int32)

-    def free(self, free_index):
+        return select_index
+
+    def free(self, free_index: int):
+        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
         else:
             self.can_use_mem_size += free_index.shape[0]
-        self.mem_state[free_index] = 1

     def clear(self):
-        self.mem_state.fill_(1)
+        self.mem_state.fill_(True)
         self.can_use_mem_size = len(self.mem_state)


 class TokenToKVPool:
+    """A memory pool that maps a token to its kv cache locations"""
+
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
-        self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.total_ref_ct = 0
+        self.size = size
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
-            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]

+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 512
+
+        self.can_use_mem_size = self.size
+        self.clear()
+
     def get_key_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 0]

     def get_value_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 1]

+    def available_size(self):
+        return self.can_use_mem_size + len(self.prefetch_buffer)
+
     def alloc(self, need_size):
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if select_index.shape[0] < need_size:
-            return None
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index

-        self.add_refs(select_index)
-        return select_index.to(torch.int32)
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)

-    def alloc_contiguous(self, need_size):
-        empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if empty_index.shape[0] < need_size:
-            return None
-        empty_size = len(empty_index)
-        loc_sum = (
-            empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
-        )
-        can_used_loc = empty_index[: empty_size - (need_size - 1)][
-            loc_sum == need_size - 1
-        ]
-        if can_used_loc.shape[0] == 0:
+        if select_index.shape[0] < addition_size:
             return None

-        start_loc = can_used_loc[0].item()
-        select_index = torch.arange(start_loc, start_loc + need_size, device="cuda")
-        self.add_refs(select_index)
-        return select_index.to(torch.int32), start_loc, start_loc + need_size
+        self.mem_state[select_index] = False
+        self.can_use_mem_size -= len(select_index)

-    def used_size(self):
-        return len(torch.nonzero(self.mem_state).squeeze(1))
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]

-    def available_size(self):
-        return torch.sum(self.mem_state == 0).item()
-
-    def add_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct += len(token_index)
-        self.mem_state[token_index] += 1
-
-    def dec_refs(self, token_index: torch.Tensor):
-        self.total_ref_ct -= len(token_index)
-        self.mem_state[token_index] -= 1
+        return ret_index

-        num_freed = torch.sum(self.mem_state[token_index] == 0)
-
-        return num_freed
+    def free(self, free_index: torch.Tensor):
+        self.mem_state[free_index] = True
+        self.can_use_mem_size += len(free_index)

     def clear(self):
-        self.mem_state.fill_(0)
-        self.total_ref_ct = 0
+        self.mem_state.fill_(True)
+        self.can_use_mem_size = self.size
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state[0] = False
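
The memory_pool.py rewrite above drops per-token reference counting (add_refs/dec_refs), used_size(), and the contiguous allocator, replacing them with a boolean free bitmap plus a prefetch buffer: alloc() pulls at least prefetch_chunk_size free slots per torch.nonzero scan and serves later small allocations straight from that buffer, while available_size() counts both free and prefetched-but-unused slots. Below is a minimal CPU-only sketch of the same chunked free-list idea; the ChunkedSlotAllocator class and its sizes are illustrative assumptions, not the package's API, and the extra padded-token slot is omitted.

# Minimal CPU-only sketch of the chunked free-list idea behind the new
# TokenToKVPool.alloc(); the class name and sizes are illustrative.
import torch


class ChunkedSlotAllocator:
    def __init__(self, size: int, chunk_size: int = 512):
        self.size = size
        self.chunk_size = chunk_size
        self.free_state = torch.ones((size,), dtype=torch.bool)  # True = slot is free
        self.prefetch = torch.empty(0, dtype=torch.int32)        # prefetched free slots
        self.num_free = size

    def available(self) -> int:
        # Free slots plus slots already pulled into the prefetch buffer.
        return self.num_free + len(self.prefetch)

    def alloc(self, need: int):
        if need <= len(self.prefetch):
            # Fast path: serve entirely from the prefetch buffer.
            out, self.prefetch = self.prefetch[:need], self.prefetch[need:]
            return out
        # Slow path: one bitmap scan grabs at least a full chunk.
        want = max(need - len(self.prefetch), self.chunk_size)
        picked = torch.nonzero(self.free_state).squeeze(1)[:want].to(torch.int32)
        if len(picked) < need - len(self.prefetch):
            return None  # out of slots
        self.free_state[picked] = False
        self.num_free -= len(picked)
        self.prefetch = torch.cat((self.prefetch, picked))
        out, self.prefetch = self.prefetch[:need], self.prefetch[need:]
        return out

    def free(self, idx: torch.Tensor):
        self.free_state[idx] = True
        self.num_free += len(idx)


pool = ChunkedSlotAllocator(size=1024)
a = pool.alloc(3)   # triggers one chunked bitmap scan
b = pool.alloc(5)   # served from the prefetch buffer, no scan
pool.free(torch.cat((a, b)))
print(pool.available())  # 1024: prefetched-but-unused slots still count

Chunking matters here because a torch.nonzero scan touches the whole bitmap regardless of how many indices are taken, so batching the scan amortizes that fixed cost across the many per-token allocations a decode step performs.
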
@@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 from torch import nn
-
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
 from vllm.model_executor.layers.activation import SiluAndMul
-
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata


 class MiniCPMMLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module):


 class MiniCPMAttention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module):


 class MiniCPMDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config,
@@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module):


 class MiniCPMModel(nn.Module):
-
     def __init__(
         self,
         config,
@@ -274,7 +267,7 @@ class MiniCPMForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-
+
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
         self.model = MiniCPMModel(config, quant_config=quant_config)