sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/bench_serving.py +11 -3
  3. sglang/lang/backend/openai.py +10 -0
  4. sglang/srt/configs/model_config.py +11 -2
  5. sglang/srt/constrained/xgrammar_backend.py +6 -0
  6. sglang/srt/layers/attention/__init__.py +0 -1
  7. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  9. sglang/srt/layers/logits_processor.py +30 -2
  10. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
  11. sglang/srt/layers/moe/topk.py +14 -0
  12. sglang/srt/layers/quantization/fp8.py +42 -2
  13. sglang/srt/layers/quantization/fp8_kernel.py +91 -18
  14. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  15. sglang/srt/managers/io_struct.py +29 -8
  16. sglang/srt/managers/schedule_batch.py +22 -15
  17. sglang/srt/managers/schedule_policy.py +1 -1
  18. sglang/srt/managers/scheduler.py +71 -34
  19. sglang/srt/managers/session_controller.py +102 -27
  20. sglang/srt/managers/tokenizer_manager.py +95 -55
  21. sglang/srt/managers/tp_worker.py +7 -0
  22. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  23. sglang/srt/model_executor/forward_batch_info.py +42 -3
  24. sglang/srt/model_executor/model_runner.py +4 -6
  25. sglang/srt/model_loader/loader.py +22 -11
  26. sglang/srt/models/gemma2.py +19 -0
  27. sglang/srt/models/llama.py +13 -2
  28. sglang/srt/models/llama_eagle.py +132 -0
  29. sglang/srt/openai_api/adapter.py +79 -2
  30. sglang/srt/openai_api/protocol.py +50 -0
  31. sglang/srt/sampling/sampling_params.py +9 -2
  32. sglang/srt/server.py +45 -39
  33. sglang/srt/server_args.py +17 -30
  34. sglang/srt/speculative/spec_info.py +19 -0
  35. sglang/srt/utils.py +62 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
  38. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
  39. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -11,12 +11,17 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+
+not_hip = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    not_hip = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -267,8 +272,14 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    # FIXME(zhyncs)
-    if num_experts >= 256:
+    if not_hip and num_experts >= 224:
+        token_cnts_buffer = torch.empty(
+            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+        )
+        cumsum_buffer = torch.empty(
+            num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        )
+
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -276,6 +287,8 @@ def moe_align_block_size(
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            token_cnts_buffer,
+            cumsum_buffer,
         )
     else:
         ops.moe_align_block_size(
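The sgl-kernel path is now taken only when the CUDA-only import succeeded (not_hip) and the layer is wide enough, and it passes two explicit scratch buffers. A rough sketch of the branch and the buffer shapes, with a made-up num_experts value:

import torch

def uses_sgl_kernel(num_experts: int, not_hip: bool) -> bool:
    # Mirrors the condition in the hunk above: CUDA build and >= 224 experts.
    return not_hip and num_experts >= 224

num_experts = 256  # illustrative value
if uses_sgl_kernel(num_experts, not_hip=True):
    # Scratch buffers handed to sgl_moe_align_block_size, sized as in the diff.
    token_cnts_buffer = torch.empty((num_experts + 1) * num_experts, dtype=torch.int32)
    cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32)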
@@ -379,14 +392,25 @@ def invoke_fused_moe_kernel(
     )
 
 
-def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
+def get_config_file_name(
+    E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
+) -> str:
     device_name = get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
+    block_shape_selector = (
+        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
+    )
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(
+    E: int,
+    N: int,
+    dtype: Optional[str],
+    block_n: Optional[int] = 0,
+    block_k: Optional[int] = 0,
+) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -398,7 +422,7 @@ get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
 
     # First look up if an optimized configuration is available in the configs
    # directory
-    json_file_name = get_config_file_name(E, N, dtype)
+    json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
 
     config_file_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
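For reference, composing the tuned-config file name the same way the updated get_config_file_name does, with made-up shapes and a made-up device name:

E, N = 256, 7168                       # illustrative expert count and intermediate size
dtype = "fp8_w8a8"
block_shape = [128, 128]
device_name = "NVIDIA_H100_80GB_HBM3"  # get_device_name() would return the real name
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = (
    "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
)
print(f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json")
# -> E=256,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json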
@@ -429,25 +453,37 @@ def get_default_config(
     topk: int,
     dtype: Optional[str],
     is_marlin: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> Dict[str, int]:
     if dtype == "fp8_w8a8":
-        config = {
-            "BLOCK_SIZE_M": 128,
-            "BLOCK_SIZE_N": 256,
-            "BLOCK_SIZE_K": 128,
-            "GROUP_SIZE_M": 32,
-            "num_warps": 8,
-            "num_stages": 4,
-        }
-        if M <= E:
+        if block_shape is None:
             config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 256,
                 "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
-                "num_warps": 4,
+                "GROUP_SIZE_M": 32,
+                "num_warps": 8,
                 "num_stages": 4,
             }
+            if M <= E:
+                config = {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 128,
+                    "BLOCK_SIZE_K": 128,
+                    "GROUP_SIZE_M": 1,
+                    "num_warps": 4,
+                    "num_stages": 4,
+                }
+        else:
+            # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_shape[0],
+                "BLOCK_SIZE_K": block_shape[1],
+                "GROUP_SIZE_M": 32,
+                "num_warps": 4,
+                "num_stages": 3,
+            }
     else:
         config = {
             "BLOCK_SIZE_M": 64,
@@ -483,7 +519,9 @@ def try_get_optimal_moe_config(
     else:
         # First try to load optimal config from the file
        E, _, N = w2_shape
-        configs = get_moe_configs(E, N, dtype)
+        block_n = block_shape[0] if block_shape else 0
+        block_k = block_shape[1] if block_shape else 0
+        configs = get_moe_configs(E, N, dtype, block_n, block_k)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -491,14 +529,9 @@
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
-            # TODO(HandH1998): Optimize the configs of block-wise quant.
-            # NOTE(HandH1998): For block-wise quant,
-            # BLOCK_K must be divisable by block_shape[1]
-            # BLOCK_N and BLOCK_M has no requirements
-            if block_shape is not None:
-                config["BLOCK_SIZE_N"] = block_shape[0]
-                config["BLOCK_SIZE_K"] = block_shape[1]
+            config = get_default_config(
+                M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
+            )
     return config
 
 
sglang/srt/layers/moe/topk.py

@@ -1,3 +1,17 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Callable, Optional
 
 import torch
sglang/srt/layers/quantization/fp8.py

@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -273,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, require_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, require_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
@@ -370,7 +382,7 @@
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=None,
                 bias=bias,
             )
 
@@ -548,8 +560,36 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
sglang/srt/layers/quantization/fp8_kernel.py

@@ -1,9 +1,34 @@
-from typing import List, Tuple
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import get_device_name, is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+logger = logging.getLogger(__name__)
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
@@ -51,7 +76,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -73,9 +98,13 @@
     assert x.is_contiguous(), "`x` is not contiguous"
 
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
     fp8_max = finfo.max
 
+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
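The 224.0 cap reflects the smaller dynamic range of the e4m3fnuz format used on ROCm. A small sketch of the resulting clamping bounds, with torch's finfo value shown for comparison (the is_hip_ flag is hard-coded here for illustration):

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max        # 448.0 for e4m3fn on CUDA
is_hip_ = False            # pretend this is a CUDA build for the sketch
if is_hip_:
    fp8_max = 224.0        # reduced range used for e4m3fnuz on ROCm
fp8_min = -fp8_max
print(fp8_min, fp8_max)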
@@ -191,6 +220,48 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@functools.lru_cache
+def get_w8a8_block_fp8_configs(
+    N: int, K: int, block_n: int, block_k: int
+) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the w8a8 block fp8 kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    device_name = get_device_name().replace(" ", "_")
+    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                "Using configuration from %s for W8A8 Block FP8 kernel.",
+                config_file_path,
+            )
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    logger.warning(
+        (
+            "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
+            "Config file not found at %s"
+        ),
+        config_file_path,
+    )
+    return None
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -231,17 +302,22 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
 
-    # TODO(HandH1998):
-    # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
-    # BLOCK_SIZE_K must be divisable by block_k
-    # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements
-    BLOCK_SIZE_M = 128
-    if M < BLOCK_SIZE_M:
-        BLOCK_SIZE_M = triton.next_power_of_2(M)
-        BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
-    BLOCK_SIZE_K = block_k
-    assert block_k % BLOCK_SIZE_K == 0
-    BLOCK_SIZE_N = block_n
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+    else:
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
 
     def grid(META):
         return (
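The same nearest-batch-size lookup is used both here and in try_get_optimal_moe_config above: the loaded config map is keyed by tuned batch sizes and the entry closest to the current M wins. A toy illustration with made-up keys:

configs = {1: "cfg_a", 16: "cfg_b", 64: "cfg_c"}   # keys are the tuned batch sizes
M = 48
chosen = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(chosen)  # "cfg_c", since 64 is the tuned batch size closest to 48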
@@ -269,10 +345,7 @@
         As.stride(-1),
         Bs.stride(1),
         Bs.stride(0),
-        BLOCK_SIZE_M=BLOCK_SIZE_M,
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-        BLOCK_SIZE_K=BLOCK_SIZE_K,
-        GROUP_SIZE_M=8,
+        **config,
     )
 
     return C
sglang/srt/layers/quantization/fp8_utils.py

@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
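A hypothetical round trip through the updated symmetric scaling, kept outside the library to show the effect of the clamp (all values illustrative; the 224.0 cap would apply on ROCm):

import torch

x = torch.randn(4, 8)
amax = x.abs().max().clamp(min=1e-12)
fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 on CUDA; 224.0 used on ROCm
scale = fp8_max / amax
x_q = (x * scale).clamp(min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn)
x_deq = x_q.to(torch.float32) / scale            # approximate reconstruction of x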
sglang/srt/managers/io_struct.py

@@ -21,10 +21,20 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
 
 
+@dataclass
+class SessionParams:
+    id: Optional[str] = None
+    rid: Optional[str] = None
+    offset: Optional[int] = None
+    replace: Optional[bool] = None
+
+
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
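With the old (session_id, session_rid) tuple replaced by a dict, a generate request that continues a session might look roughly like the sketch below. The session id and offset values are made up, and the exact server-side semantics of offset/replace are not documented in this diff; the keys simply mirror the SessionParams dataclass above:

payload = {
    "text": "Continue the story.",
    "sampling_params": {"max_new_tokens": 64},
    "session_params": {"id": "my-session", "rid": None, "offset": 0, "replace": False},
}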
@@ -56,10 +66,8 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Session id info for continual prompting
-    session: Optional[
-        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
-    ] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -221,9 +229,8 @@ class TokenizedGenerateReqInput:
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
-    # Session id info for continual prompting
-    session_id: Optional[str] = None
-    session_rid: Optional[str] = None
+    # Session info for continual prompting
+    session_params: Optional[SessionParams] = None
 
 
 @dataclass
@@ -407,6 +414,18 @@ class UpdateWeightsFromDistributedReqOutput:
     message: str
 
 
+@dataclass
+class UpdateWeightsFromTensorReqInput:
+    name: str
+    tensor: torch.Tensor
+
+
+@dataclass
+class UpdateWeightsFromTensorReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsUpdateGroupReqInput:
     # The master address
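The new structs carry a single named tensor. Constructing them directly looks like the sketch below; the parameter name and tensor contents are made up, and how the scheduler consumes these requests is outside this hunk:

import torch
from sglang.srt.managers.io_struct import (
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)

req = UpdateWeightsFromTensorReqInput(
    name="model.layers.0.mlp.gate_proj.weight",  # hypothetical parameter name
    tensor=torch.zeros(8, 16),
)
ack = UpdateWeightsFromTensorReqOutput(success=True, message="weights updated")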
@@ -454,6 +473,7 @@ class ProfileReq(Enum):
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
+    session_id: Optional[str] = None
 
 
 @dataclass
@@ -463,4 +483,5 @@ class CloseSessionReqInput:
 
 @dataclass
 class OpenSessionReqOutput:
-    session_id: str
+    session_id: Optional[str]
+    success: bool
sglang/srt/managers/schedule_batch.py

@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -236,6 +237,7 @@
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@
 
         last_token_id = self.output_ids[-1]
 
-        matched_eos = False
-
-        # Check stop token ids
-        if self.sampling_params.stop_token_ids:
-            matched_eos = last_token_id in self.sampling_params.stop_token_ids
-        if self.tokenizer is not None:
-            matched_eos |= last_token_id == self.tokenizer.eos_token_id
-            if self.tokenizer.additional_stop_token_ids:
-                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
-        if matched_eos and not self.sampling_params.ignore_eos:
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-            return
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
+
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return
 
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
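Condensed, the new control flow skips all EOS matching when ignore_eos is set and folds the per-request eos_token_ids into the same check. A standalone sketch of that logic (not the actual Req method):

def eos_matched(last_token_id, stop_token_ids, eos_token_ids, tokenizer, ignore_eos):
    if ignore_eos:
        return False
    matched = bool(stop_token_ids) and last_token_id in stop_token_ids
    if eos_token_ids:
        matched |= last_token_id in eos_token_ids
    if tokenizer is not None:
        matched |= last_token_id == tokenizer.eos_token_id
        if getattr(tokenizer, "additional_stop_token_ids", None):
            matched |= last_token_id in tokenizer.additional_stop_token_ids
    return matched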
@@ -836,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
sglang/srt/managers/schedule_policy.py

@@ -248,7 +248,7 @@ class PrefillAdder:
         self.can_run_list.append(req)
 
         self._prefill_one_req(
-            len(req.prefix_indices),
+            0,
             req.extend_input_len,
             (
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)