sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py

@@ -1,12 +1,17 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
+import inspect

 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -32,48 +37,43 @@ class DetokenizerManager:

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)

-            if isinstance(recv_obj, BatchTokenIDOut):
-                output_tokens = recv_obj.output_tokens
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )

-                # TODO(lmzheng): handle skip_special_tokens per request
-                output_strs = self.tokenizer.batch_decode(
-                    output_tokens,
-                    skip_special_tokens=recv_obj.skip_special_tokens[0],
-                )
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                if recv_obj.finished_reason[i] is None:
+                    new_text = find_printable_text(new_text)
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

-                # Trim stop str
-                # TODO(lmzheng): handle the case where multiple stop strs are hit
-                for i in range(len(output_strs)):
-                    if recv_obj.hit_stop_str[i] is not None:
-                        pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                        if pos != -1:
-                            output_strs[i] = output_strs[i][:pos]
-
-                    if len(output_tokens[i]) > 0:
-                        first_token = self.tokenizer.convert_ids_to_tokens(
-                            int(output_tokens[i][0])
-                        )
-                        if not isinstance(first_token, str):
-                            first_token = first_token.decode("utf-8", errors="ignore")
-                        if first_token.startswith("▁"):
-                            output_strs[i] = " " + output_strs[i]
-
-                    output_strs[i] = (
-                        recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                    )
-
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
-                    )
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_strs=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
                 )
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+            )


 def start_detokenizer_process(
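
The rewritten handle_loop implements incremental detokenization: surr_output_ids is a small surrogate window of recently emitted tokens that gives the tokenizer left context, read_output_ids extends that window with the new tokens, and the freshly generated text is the suffix of the read decode beyond the surrogate decode. A minimal standalone sketch of the idea, using a Hugging Face tokenizer (the helper name and model choice are illustrative, not sglang API):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

def incremental_decode(surr_ids, read_ids):
    # Decode the surrogate window and the extended window, then take the
    # suffix: decoding tokens together with their left context yields
    # stable text even when a token renders differently on its own.
    surr_text = tok.decode(surr_ids, skip_special_tokens=True)
    read_text = tok.decode(read_ids, skip_special_tokens=True)
    return read_text[len(surr_text):]

ids = tok.encode("Hello world, this is a streaming test")
# Pretend the first four tokens were already emitted.
print(incremental_decode(ids[:4], ids))  # prints the newly decoded suffix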
@@ -81,9 +81,11 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception as e:
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
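
start_detokenizer_process now registers a shutdown hook via graceful_registry under the caller's own function name. The helper lives in sglang.utils; a sketch of the pattern it presumably follows (the handler body below is an assumption, not the actual implementation):

import inspect
import logging
import signal

def graceful_registry(sub_module_name: str):
    # Log which subprocess received SIGTERM so a multi-process server can
    # attribute shutdowns; the real helper may do more than log.
    def graceful_shutdown(signum, frame):
        logging.info(f"{sub_module_name} received signal {signum}, shutting down")

    signal.signal(signal.SIGTERM, graceful_shutdown)

graceful_registry(inspect.currentframe().f_code.co_name)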
sglang/srt/managers/io_struct.py

@@ -1,14 +1,22 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

+from sglang.srt.managers.controller.infer_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams


 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -19,13 +27,24 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False

     def post_init(self):
-        is_single = isinstance(self.text, str)
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single

         if is_single:
             if self.sampling_params is None:
@@ -36,8 +55,10 @@ class GenerateReqInput:
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)

             if self.image_data is None:
                 self.image_data = [None] * num
@@ -52,7 +73,8 @@ class GenerateReqInput:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                assert isinstance(self.rid, list)
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")

             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -64,6 +86,11 @@ class GenerateReqInput:
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num

+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = [0] * num
+            elif not isinstance(self.top_logprobs_num, list):
+                self.top_logprobs_num = [self.top_logprobs_num] * num
+

 @dataclass
 class TokenizedGenerateReqInput:
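
With these changes a GenerateReqInput carries exactly one of text or input_ids, and post_init both validates that invariant and normalizes the per-request fields (rid, return_logprob, logprob_start_len, top_logprobs_num) into scalars or per-item lists. A usage sketch, assuming sglang 0.1.21 is installed:

from sglang.srt.managers.io_struct import GenerateReqInput

# Single request: text given, ids omitted; defaults are filled in.
req = GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 8})
req.post_init()
assert req.is_single and req.top_logprobs_num == 0

# Batch request: pre-tokenized ids; per-item defaults become lists.
batch = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5]])
batch.post_init()
assert not batch.is_single and len(batch.rid) == 2

# Setting both text and input_ids (or neither) now raises ValueError.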
@@ -76,26 +103,28 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
     stream: bool


 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
-    hit_stop_str: List[Optional[str]]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
-    finished: List[bool]
+    finished_reason: List[BaseFinishReason]


 @dataclass
 class BatchStrOut:
     rids: List[str]
-    output_str: List[str]
+    output_strs: List[str]
     meta_info: List[Dict]
-    finished: List[bool]
+    finished_reason: List[BaseFinishReason]


 @dataclass
@@ -103,6 +132,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
sglang/srt/managers/tokenizer_manager.py

@@ -1,15 +1,20 @@
+"""TokenizerManager is a process that tokenizes the text."""
+
 import asyncio
 import concurrent.futures
 import dataclasses
+import logging
 import multiprocessing as mp
 import os
-from typing import List
+from typing import Dict, List

 import numpy as np
 import transformers
 import uvloop
 import zmq
 import zmq.asyncio
+from fastapi import BackgroundTasks
+
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -17,8 +22,9 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
-    DetokenizeReqInput,
+    BatchTokenIDOut,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -26,54 +32,19 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
+from sglang.srt.utils import is_multimodal_model, load_image
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
-    lock: asyncio.Lock
-
-
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image = load_image(image_data)
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
-                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-            )
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        elif image_aspect_ratio == "anyres":
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
-        else:
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())


 class TokenizerManager:
@@ -81,6 +52,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args

@@ -93,9 +65,10 @@

         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path, trust_remote_code=server_args.trust_remote_code
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
-
         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
@@ -119,7 +92,7 @@
         )

         self.to_create_loop = True
-        self.rid_to_state = {}  # Dict[str -> ReqState]
+        self.rid_to_state: Dict[str, ReqState] = {}

     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -140,15 +113,26 @@
                 image_data, aspect_ratio, grid_pinpoints, self.processor
             )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
-
-        is_single = isinstance(obj.text, str)
+            self.create_handle_loop()

+        obj.post_init()
+        is_single = obj.is_single
         if is_single:
             rid = obj.rid
-            input_ids = self.tokenizer.encode(obj.text)
+
+            if obj.input_ids is None:
+                input_ids = self.tokenizer.encode(obj.text)
+            else:
+                input_ids = obj.input_ids
+
+            if len(input_ids) >= self.context_len:
+                raise ValueError(
+                    f"The input ({len(input_ids)} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
+
             sampling_params = SamplingParams(**obj.sampling_params)
             if sampling_params.max_new_tokens != 0:
                 sampling_params.normalize(self.tokenizer)
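
Because generate_request now accepts pre-tokenized input and enforces the context-length cap up front, a client can bypass server-side tokenization. A hedged sketch of such a call against a locally running server (the default port and the token ids are placeholder assumptions):

import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        # input_ids replaces text; sending both now raises a ValueError,
        # as does a prompt at or beyond the model's context length.
        "input_ids": [1, 15043, 3186],
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(resp.json()["text"])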
@@ -174,29 +158,64 @@
                 sampling_params=sampling_params,
                 return_logprob=obj.return_logprob,
                 logprob_start_len=obj.logprob_start_len,
+                top_logprobs_num=obj.top_logprobs_num,
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)

-            lock = asyncio.Lock()
             event = asyncio.Event()
-            state = ReqState([], False, event, lock)
+            state = ReqState([], False, event)
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
-                yield state.out_list[-1]
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=4)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
+
+                if self.server_args.log_requests and state.finished:
+                    logger.info(f"in={obj.text}, out={out}")
+
                 state.out_list = []
                 if state.finished:
                     del self.rid_to_state[rid]
+
+                    yield out
+
                     break
+
                 event.clear()
+
+                yield out
         else:
-            assert obj.stream is False
-            bs = len(obj.text)
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")
+
+            if obj.input_ids is None:
+                bs = len(obj.text)
+            else:
+                bs = len(obj.input_ids)
+
             for i in range(bs):
                 rid = obj.rid[i]
-                input_ids = self.tokenizer.encode(obj.text[i])
+
+                if obj.input_ids is None:
+                    input_text = obj.text[i]
+                    input_ids = self.tokenizer.encode(obj.text[i])
+                else:
+                    input_text = None
+                    input_ids = obj.input_ids[i]
+
                 sampling_params = SamplingParams(**obj.sampling_params[i])
                 if sampling_params.max_new_tokens != 0:
                     sampling_params.normalize(self.tokenizer)
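
Both the streaming and the batch paths replace a bare event.wait() with a 4-second asyncio.wait_for poll, so a disconnected client is detected even when no new output arrives to set the event. The pattern in isolation (the names here are illustrative):

import asyncio

async def wait_or_abort(event: asyncio.Event, is_disconnected, abort, poll_s: float = 4.0):
    # Wake periodically; if the client went away, abort the request
    # instead of waiting forever on an event that may never fire.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=poll_s)
            return
        except asyncio.TimeoutError:
            if await is_disconnected():
                abort()
                raise ValueError("client disconnected; request aborted")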
@@ -209,7 +228,7 @@
                 )
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid=rid,
-                    input_text=obj.text[i],
+                    input_text=input_text,
                     input_ids=input_ids,
                     pixel_values=pixel_values,
                     image_hash=image_hash,
@@ -217,53 +236,176 @@
                     sampling_params=sampling_params,
                     return_logprob=obj.return_logprob[i],
                     logprob_start_len=obj.logprob_start_len[i],
+                    top_logprobs_num=obj.top_logprobs_num[i],
                     stream=obj.stream,
                 )
                 self.send_to_router.send_pyobj(tokenized_obj)

-                lock = asyncio.Lock()
                 event = asyncio.Event()
-                state = ReqState([], False, event, lock)
+                state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

             output_list = []
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
-                output_list.append(state.out_list[-1])
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=4)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[i],
+                        obj.top_logprobs_num[i],
+                        obj.return_text_in_logprobs,
+                    )
+                )
                 assert state.finished
                 del self.rid_to_state[rid]

             yield output_list

-    async def detokenize(self, obj: DetokenizeReqInput):
-        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
-        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)
+
+    def abort_request(self, rid):
+        if rid not in self.rid_to_state:
+            return
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_abort_task(self, obj: GenerateReqInput):
+        # Abort the request if the client is disconnected.
+        async def abort_request():
+            await asyncio.sleep(3)
+            if obj.is_single:
+                self.abort_request(obj.rid)
+            else:
+                for rid in obj.rids:
+                    self.abort_request(rid)

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(abort_request)
+        return background_tasks

-    async def create_handle_loop(self):
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())

     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, BatchStrOut):
-                for i, rid in enumerate(recv_obj.rids):
-                    recv_obj.meta_info[i]["id"] = rid
-                    out_dict = {
-                        "text": recv_obj.output_str[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                    state = self.rid_to_state[rid]
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished[i]
-                    state.event.set()
+            recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
+            assert isinstance(recv_obj, BatchStrOut)
+
+            for i, rid in enumerate(recv_obj.rids):
+                state = self.rid_to_state.get(rid, None)
+                if state is None:
+                    continue
+
+                recv_obj.meta_info[i]["id"] = rid
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": recv_obj.meta_info[i],
+                }
+                state.out_list.append(out_dict)
+                state.finished = recv_obj.finished_reason[i] is not None
+                state.event.set()
+
+    def convert_logprob_style(
+        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+    ):
+        if return_logprob:
+            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            )
+            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            )
+
+            if top_logprobs_num > 0:
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
+        return ret
+
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+        if not decode_to_text:
+            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+        token_ids = [tid for _, tid in token_logprobs]
+        token_texts = self.tokenizer.batch_decode(token_ids)
+        return [
+            (logprob, token_id, token_text)
+            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+        ]
+
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+        for i, t in enumerate(top_logprobs):
+            if t:
+                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        return top_logprobs
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size is not None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
             else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())