sglang 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (82)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/interpreter.py +21 -5
  9. sglang/lang/ir.py +1 -2
  10. sglang/srt/constrained/__init__.py +15 -0
  11. sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +17 -2
  12. sglang/srt/constrained/fsm_cache.py +17 -2
  13. sglang/srt/constrained/jump_forward.py +17 -2
  14. sglang/srt/conversation.py +26 -0
  15. sglang/srt/hf_transformers_utils.py +15 -0
  16. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  17. sglang/srt/layers/extend_attention.py +15 -0
  18. sglang/srt/layers/fused_moe.py +15 -0
  19. sglang/srt/layers/linear.py +15 -0
  20. sglang/srt/layers/logits_processor.py +41 -13
  21. sglang/srt/layers/quantization/__init__.py +15 -0
  22. sglang/srt/layers/quantization/fp8.py +15 -0
  23. sglang/srt/layers/radix_attention.py +17 -2
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  26. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  27. sglang/srt/managers/detokenizer_manager.py +16 -1
  28. sglang/srt/managers/io_struct.py +36 -3
  29. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  30. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +60 -21
  31. sglang/srt/managers/tokenizer_manager.py +39 -16
  32. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +159 -46
  33. sglang/srt/mem_cache/base_cache.py +43 -0
  34. sglang/srt/mem_cache/chunk_cache.py +60 -0
  35. sglang/srt/mem_cache/flush_cache.py +33 -0
  36. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  37. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +20 -2
  38. sglang/srt/mm_utils.py +15 -0
  39. sglang/srt/model_config.py +15 -0
  40. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
  41. sglang/srt/{managers/controller → model_executor}/model_runner.py +49 -14
  42. sglang/srt/model_loader/model_loader.py +15 -0
  43. sglang/srt/model_loader/utils.py +16 -1
  44. sglang/srt/models/chatglm.py +16 -1
  45. sglang/srt/models/commandr.py +16 -1
  46. sglang/srt/models/dbrx.py +16 -1
  47. sglang/srt/models/deepseek.py +16 -1
  48. sglang/srt/models/deepseek_v2.py +16 -1
  49. sglang/srt/models/gemma.py +16 -1
  50. sglang/srt/models/gemma2.py +16 -1
  51. sglang/srt/models/gpt_bigcode.py +16 -1
  52. sglang/srt/models/grok.py +16 -1
  53. sglang/srt/models/internlm2.py +16 -1
  54. sglang/srt/models/llama2.py +21 -22
  55. sglang/srt/models/llama_classification.py +16 -1
  56. sglang/srt/models/llava.py +17 -2
  57. sglang/srt/models/llavavid.py +17 -2
  58. sglang/srt/models/minicpm.py +16 -1
  59. sglang/srt/models/mistral.py +15 -0
  60. sglang/srt/models/mixtral.py +16 -1
  61. sglang/srt/models/mixtral_quant.py +16 -1
  62. sglang/srt/models/qwen.py +16 -1
  63. sglang/srt/models/qwen2.py +16 -1
  64. sglang/srt/models/qwen2_moe.py +16 -1
  65. sglang/srt/models/stablelm.py +16 -1
  66. sglang/srt/models/yivl.py +15 -0
  67. sglang/srt/openai_api/adapter.py +569 -131
  68. sglang/srt/openai_api/protocol.py +84 -2
  69. sglang/srt/sampling_params.py +15 -0
  70. sglang/srt/server.py +92 -23
  71. sglang/srt/server_args.py +52 -11
  72. sglang/srt/utils.py +15 -0
  73. sglang/test/test_programs.py +9 -6
  74. sglang/utils.py +22 -0
  75. sglang/version.py +1 -1
  76. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/METADATA +33 -7
  77. sglang-0.2.8.dist-info/RECORD +95 -0
  78. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/WHEEL +1 -1
  79. sglang/srt/flush_cache.py +0 -18
  80. sglang-0.2.6.dist-info/RECORD +0 -93
  81. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/LICENSE +0 -0
  82. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/top_level.txt +0 -0
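
Before the per-file hunks: the dominant change in 0.2.8 is a package reorganization. `managers/controller/*` is flattened into `managers/`, cache code is consolidated under `srt/mem_cache/`, and the model runner moves to `srt/model_executor/`. The sketch below is a hypothetical helper (not shipped with sglang) that records the import-path moves visible in the file list above; downstream code importing the old paths must be updated accordingly.

    # Hypothetical migration table (not part of sglang): 0.2.6 module paths
    # mapped to their 0.2.8 locations, per the renames in the file list above.
    RENAMED_MODULES = {
        "sglang.srt.managers.controller.manager_multi": "sglang.srt.managers.controller_multi",
        "sglang.srt.managers.controller.manager_single": "sglang.srt.managers.controller_single",
        "sglang.srt.managers.controller.schedule_heuristic": "sglang.srt.managers.policy_scheduler",
        "sglang.srt.managers.controller.infer_batch": "sglang.srt.managers.schedule_batch",
        "sglang.srt.managers.controller.tp_worker": "sglang.srt.managers.tp_worker",
        "sglang.srt.managers.controller.radix_cache": "sglang.srt.mem_cache.radix_cache",
        "sglang.srt.memory_pool": "sglang.srt.mem_cache.memory_pool",
        "sglang.srt.managers.controller.cuda_graph_runner": "sglang.srt.model_executor.cuda_graph_runner",
        "sglang.srt.managers.controller.model_runner": "sglang.srt.model_executor.model_runner",
    }

    def migrate_import(module_path: str) -> str:
        """Return the 0.2.8 location of a module (identity if it did not move)."""
        return RENAMED_MODULES.get(module_path, module_path)

    assert migrate_import("sglang.srt.managers.controller.infer_batch") == (
        "sglang.srt.managers.schedule_batch"
    )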
sglang/srt/layers/token_attention.py
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -5,7 +20,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
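
Besides the import move, the second hunk shows the kernel's import-time configuration: the Triton reduction dtype is fixed once, when the module loads, from `global_server_args_dict`. A minimal standalone sketch of that pattern; the stand-in dict value and the float16 fallback branch are assumptions, since the hunk only shows the float32 side.

    import triton.language as tl

    # Stand-in; in sglang this dict is populated from the server arguments
    # before the kernel module is imported.
    global_server_args_dict = {"attention_reduce_in_fp32": False}

    if global_server_args_dict.get("attention_reduce_in_fp32", False):
        REDUCE_TRITON_TYPE = tl.float32  # higher-precision attention reduction
    else:
        REDUCE_TRITON_TYPE = tl.float16  # assumed fallback (not shown in the hunk)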
sglang/srt/managers/{controller/manager_multi.py → controller_multi.py}
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
@@ -12,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
@@ -24,7 +39,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger("srt.controller")
+logger = logging.getLogger(__name__)
 
 
 class LoadBalanceMethod(Enum):
sglang/srt/managers/{controller/manager_single.py → controller_single.py}
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """A controller that manages a group of tensor parallel workers."""
 
 import logging
@@ -7,7 +22,7 @@ from typing import List
 
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
@@ -16,7 +31,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger("srt.controller")
+logger = logging.getLogger(__name__)
 
 
 class ControllerSingle:
sglang/srt/managers/detokenizer_manager.py
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
@@ -10,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
sglang/srt/managers/io_struct.py
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
@@ -7,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
@@ -64,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
-            parallel_sample_num = self.sampling_params.get("n", 1)
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num
 
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
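
The new `else` branch normalizes the parallel-sampling factor `n` whether `sampling_params` arrives as one dict or as a per-prompt list, and rejects mixed values. A condensed standalone restatement of that logic (the helper name is ours; the behavior mirrors the hunk):

    from typing import Dict, List, Union

    def normalize_parallel_sample_num(
        sampling_params: Union[Dict, List[Dict], None]
    ) -> int:
        """Condensed restatement of the GenerateReqInput post-init branch above."""
        if isinstance(sampling_params, dict):
            return sampling_params.get("n", 1)
        if isinstance(sampling_params, list):
            nums = [sp.get("n", 1) for sp in sampling_params]
            n = max(nums)
            if n > 1 and any(x != n for x in nums):
                raise ValueError(
                    "The parallel_sample_num should be the same for all samples in sample params."
                )
            return n
        return 1

    assert normalize_parallel_sample_num({"n": 4}) == 4
    assert normalize_parallel_sample_num([{"n": 2}, {"n": 2}]) == 2
    assert normalize_parallel_sample_num(None) == 1  # no sampling params at all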
sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py}
@@ -1,46 +1,61 @@
-"""Request scheduler heuristic."""
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
            # LMP is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
+            policy = "fcfs"
 
-        self.schedule_heuristic = schedule_heuristic
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, forward_queue):
-        if self.schedule_heuristic == "lpm":
+    def get_priority_queue(self, waiting_queue):
+        if self.policy == "lpm":
             # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
-        elif self.schedule_heuristic == "fcfs":
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
+        elif self.policy == "fcfs":
             # first come first serve
-            return forward_queue
-        elif self.schedule_heuristic == "lof":
+            return waiting_queue
+        elif self.policy == "lof":
             # longest output first
-            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return forward_queue
-        elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
-        elif self.schedule_heuristic == "dfs-weight":
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
+        elif self.policy == "random":
+            random.shuffle(waiting_queue)
+            return waiting_queue
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -52,10 +67,10 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(forward_queue)
+            assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
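
The rename is mechanical (`ScheduleHeuristic` → `PolicyScheduler`, `forward_queue` → `waiting_queue`), but the five policies are worth spelling out: `lpm` (longest prefix match), `fcfs` (first come first serve), `lof` (longest output first), `random`, and `dfs-weight`. A small self-contained illustration of the `lpm` ordering, using stand-in request objects in place of sglang's real `Req`:

    from dataclasses import dataclass, field

    @dataclass
    class FakeReq:  # stand-in for sglang's Req, enough to exercise the sort key
        rid: str
        prefix_indices: list = field(default_factory=list)

    waiting_queue = [
        FakeReq("a", prefix_indices=[0, 1]),
        FakeReq("b", prefix_indices=[0, 1, 2, 3]),
        FakeReq("c"),
    ]

    # "lpm": requests that reuse the most of the radix cache run first.
    waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
    assert [r.rid for r in waiting_queue] == ["b", "a", "c"]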
sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py}
@@ -1,5 +1,21 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Meta data for requests and batches"""
 
+import logging
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
@@ -12,8 +28,9 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -25,6 +42,9 @@ global_server_args_dict = {
 }
 
 
+logger = logging.getLogger(__name__)
+
+
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
@@ -364,7 +384,7 @@ class Batch:
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
         if out_cache_loc is None:
-            print("Prefill out of memory. This should never happen.")
+            logger.error("Prefill out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -467,15 +487,33 @@
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
-            # TODO: apply more fine-grained retraction
-            last_uncached_pos = len(req.prefix_indices)
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.free(token_indices)
-
-            # release the last node
-            self.tree_cache.dec_lock_ref(req.last_node)
+            if isinstance(self.tree_cache, ChunkCache):
+                # ChunkCache does not have eviction
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][: seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                del self.tree_cache.entries[req.rid]
+            else:
+                # TODO: apply more fine-grained retraction
+                last_uncached_pos = len(req.prefix_indices)
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][last_uncached_pos : seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+
+                # release the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
+
+                # NOTE(lsyin): we should use the newly evictable memory instantly.
+                residual_size = (
+                    len(sorted_indices) * global_config.retract_decode_steps
+                    - self.token_to_kv_pool.available_size()
+                )
+                residual_size = max(0, residual_size)
+                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
 
             req.prefix_indices = None
             req.last_node = None
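
The NOTE(lsyin) block appears to make retraction pay off immediately: it targets enough free KV slots for every request still in the batch (`len(sorted_indices)`) to decode `global_config.retract_decode_steps` more steps, and evicts whatever is still missing from the radix cache on the spot. A worked instance of that arithmetic, with illustrative numbers:

    # Illustrative numbers only; retract_decode_steps really comes from global_config.
    remaining_reqs = 4          # len(sorted_indices): requests still in the batch
    retract_decode_steps = 20   # decode steps the retraction should buy them
    available = 50              # token_to_kv_pool.available_size() after freeing

    residual_size = max(0, remaining_reqs * retract_decode_steps - available)
    assert residual_size == 30  # evict 30 more tokens from the tree cache now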
@@ -556,6 +594,7 @@ class Batch:
            if req_pool_indices_cpu is None:
                req_pool_indices_cpu = self.req_pool_indices.tolist()
            self.tree_cache.cache_req(
+                rid=req.rid,
                token_ids=cur_all_ids,
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
@@ -598,7 +637,7 @@
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-            print("Decode out of memory. This should never happen.")
+            logger.error("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -762,7 +801,7 @@
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    use_ragged: bool = False
+    flashinfer_use_ragged: bool = False
 
     @classmethod
     def create(
@@ -778,10 +817,10 @@
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
-        use_ragged = False
+        flashinfer_use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                use_ragged = True
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -789,7 +828,7 @@
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-                use_ragged,
+                flashinfer_use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -844,7 +883,7 @@
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            use_ragged=use_ragged,
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -865,7 +904,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-    use_ragged=False,
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -874,7 +913,7 @@
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if use_ragged:
+    if flashinfer_use_ragged:
        paged_kernel_lens = prefix_lens
    else:
        paged_kernel_lens = seq_lens
@@ -910,7 +949,7 @@
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        if use_ragged:
+        if flashinfer_use_ragged:
            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                qo_indptr,
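
Renaming `use_ragged` to `flashinfer_use_ragged` only scopes the flag to its backend; the trigger is unchanged: ragged prefill is used for non-decode batches whose summed sequence length exceeds 4096 tokens, and never when FlashInfer is disabled. A plain-int restatement of that condition:

    from typing import List

    def should_use_ragged_prefill(
        is_decode: bool, seq_lens: List[int], disable_flashinfer: bool = False
    ) -> bool:
        """Restates the flashinfer_use_ragged trigger with ints instead of tensors."""
        if disable_flashinfer or is_decode:
            return False
        return sum(seq_lens) > 4096

    assert should_use_ragged_prefill(False, [3000, 2000]) is True   # 5000 > 4096
    assert should_use_ragged_prefill(False, [1024, 1024]) is False  # small prefill
    assert should_use_ragged_prefill(True, [8192]) is False         # decode path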
sglang/srt/managers/tokenizer_manager.py
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
@@ -6,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import numpy as np
 import transformers
@@ -69,6 +84,7 @@ class TokenizerManager:
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )
+
        if server_args.context_length is not None:
            self.context_len = server_args.context_length
        else:
@@ -137,31 +153,33 @@
        self, obj, request, index=None, is_cache_for_prefill=False
    ):
        if not is_cache_for_prefill:
-            rid = obj.rid if index is None else obj.rid[index]
-            input_text = obj.text if index is None else obj.text[index]
+            not_use_index = not (index is not None)
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
            input_ids = (
                self.tokenizer.encode(input_text)
                if obj.input_ids is None
                else obj.input_ids
            )
-            if index is not None and obj.input_ids:
+            if not not_use_index and obj.input_ids:
                input_ids = obj.input_ids[index]
 
            self._validate_input_length(input_ids)
+
            sampling_params = self._get_sampling_params(
-                obj.sampling_params if index is None else obj.sampling_params[index]
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )
            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if index is None else obj.image_data[index]
+                obj.image_data if not_use_index else obj.image_data[index]
            )
            return_logprob = (
-                obj.return_logprob if index is None else obj.return_logprob[index]
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
            )
            logprob_start_len = (
-                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
            )
            top_logprobs_num = (
-                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
            )
        else:
            if isinstance(obj.text, list):
@@ -209,7 +227,7 @@
 
    async def _handle_batch_request(self, obj: GenerateReqInput, request):
        batch_size = obj.batch_size
-        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+        parallel_sample_num = obj.parallel_sample_num
 
        if parallel_sample_num != 1:
            # Send prefill requests to cache the common input
@@ -226,7 +244,6 @@
                    obj.input_ids = input_id_result
                elif input_id_result is not None:
                    obj.input_ids = input_id_result[0]
-
        # First send out all requests
        for i in range(batch_size):
            for j in range(parallel_sample_num):
@@ -234,7 +251,7 @@
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
+                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
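
The corrected comment describes the request-index layout under parallel sampling: the first `batch_size` entries are the shared-prefix prefill requests, the `j == 0` slot is covered by that earlier prefill pass, and sample `j` of prompt `i` lands at `i * n + j + (batch_size - 1 - i)`, which is exactly the comment's `j + i * (n - 1) + batch_size - 1`. A quick check of the identity (the layout interpretation is ours; the arithmetic is from the code):

    batch_size, n = 2, 3  # two prompts, three parallel samples each

    for i in range(batch_size):
        for j in range(n):
            if j == 0:  # skipped in the loop above; handled by the prefill stage
                continue
            index = i * n + j + (batch_size - 1 - i)
            assert index == j + i * (n - 1) + batch_size - 1  # comment's formula
    # prompt 0 -> rid indices 2, 3; prompt 1 -> rid indices 4, 5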
@@ -469,7 +486,9 @@
        )
        return ret
 
-    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
+    def detokenize_logprob_tokens(
+        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+    ):
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
@@ -481,9 +500,13 @@
        ]
 
    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
-        for i, t in enumerate(top_logprobs):
-            if t:
-                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
+        # We should batch all top-k tokens in all positions.
+        for i, token_top_logprobs in enumerate(top_logprobs):
+            if token_top_logprobs:
+                top_logprobs[i] = self.detokenize_logprob_tokens(
+                    token_top_logprobs, decode_to_text
+                )
        return top_logprobs
 
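
The refactor still detokenizes one position's top-k list per call; the new TODO asks for a single batch over all positions. A hedged sketch of what that could look like (assuming a HuggingFace-style `tokenizer.batch_decode`; this is not the shipped implementation):

    from typing import List, Optional, Tuple

    def detokenize_all_positions(
        tokenizer,  # anything exposing a HuggingFace-style batch_decode
        top_logprobs: List[Optional[List[Tuple[float, int]]]],
    ):
        """Sketch of the TODO: one batch_decode call across every position's top-k."""
        flat_ids = [tid for pos in top_logprobs if pos for _, tid in pos]
        texts = iter(tokenizer.batch_decode([[tid] for tid in flat_ids]))
        return [
            [(lp, tid, next(texts)) for lp, tid in pos] if pos else pos
            for pos in top_logprobs
        ]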