sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -16,10 +16,12 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""

 import asyncio
+import copy
 import dataclasses
-import json
 import logging
 import os
+import signal
+import sys
 from typing import Dict, List, Optional, Tuple, Union

 import fastapi
@@ -28,12 +30,8 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

-from sglang.srt.hf_transformers_utils import (
-    get_config,
-    get_context_length,
-    get_processor,
-    get_tokenizer,
-)
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
     get_image_processor,
@@ -46,17 +44,17 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
-    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_generation_model, is_multimodal_model
+from sglang.srt.utils import get_zmq_socket, kill_child_process

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -80,30 +78,32 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer = context.socket(zmq.PULL)
-        self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
-
-        self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )

         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.hf_config = get_config(
-            self.model_path,
+        self.model_config = ModelConfig(
+            server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=json.loads(server_args.json_model_override_args),
-        )
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, self.server_args.is_embedding
-        )
-        self.context_len = server_args.context_length or get_context_length(
-            self.hf_config
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+
+        self.is_generation = self.model_config.is_generation
+        self.context_len = self.model_config.context_len
+
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()

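Both IPC sockets are now created through the get_zmq_socket helper imported from sglang.srt.utils. The helper itself is not shown in this diff; the sketch below is only an assumption of what such a wrapper typically does (bind the receiving PULL socket, connect the sending PUSH socket, relax the high-water marks), not the implementation shipped in sglang.

    # Hypothetical get_zmq_socket-style wrapper (assumed behavior, not sglang's code).
    import zmq
    import zmq.asyncio


    def make_ipc_socket(context: zmq.asyncio.Context, socket_type: int, endpoint: str):
        socket = context.socket(socket_type)
        socket.setsockopt(zmq.SNDHWM, 0)  # 0 = unlimited send queue
        socket.setsockopt(zmq.RCVHWM, 0)  # 0 = unlimited receive queue
        if socket_type == zmq.PULL:
            socket.bind(f"ipc://{endpoint}")  # consumer side owns the endpoint
        elif socket_type == zmq.PUSH:
            socket.connect(f"ipc://{endpoint}")  # producer side connects to it
        else:
            raise ValueError(f"Unsupported socket type: {socket_type}")
        return socket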
@@ -111,7 +111,7 @@ class TokenizerManager:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -122,7 +122,7 @@ class TokenizerManager:

                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.model_config.hf_config, server_args, self.processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
@@ -139,9 +139,12 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None

+        # Others
+        self.gracefully_exit = False
+
     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
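generate_request keeps its async-generator interface after RewardReqInput is dropped from the signature above: a single request is tokenized, pushed to the scheduler, and its responses are streamed back one dict at a time (see the next hunk). The following driver is purely illustrative; the GenerateReqInput keyword arguments and the already-running manager are assumptions, not part of this diff.

    # Hypothetical caller of the async-generator API; `manager` is an already
    # initialized TokenizerManager backed by a running scheduler/detokenizer.
    from sglang.srt.managers.io_struct import GenerateReqInput


    async def run_one_prompt(manager):
        req = GenerateReqInput(
            text="What is the capital of France?",
            sampling_params={"max_new_tokens": 32, "temperature": 0.0},
            stream=True,
        )
        # Each yielded item is a dict carrying the generated text/token ids plus
        # a "meta_info" entry; streaming requests yield intermediate chunks.
        async for out in manager.generate_request(req):
            print(out["meta_info"]["id"], out.get("text"))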
@@ -152,133 +155,58 @@ class TokenizerManager:

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )

-        obj.post_init()
+        obj.normalize_batch_and_arguments()
         is_single = obj.is_single
         if is_single:
-            async for response in self._handle_single_request(obj, request):
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(obj, request):
                 yield response

-    async def _send_single_request(
+    async def _tokenize_one_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
     ):
-        if not is_cache_for_prefill:  # The normal case with a single prompt
-            if index is None:
-                rid = obj.rid
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = obj.text if obj.text is not None else None
-                    input_ids = obj.input_ids
-
-                sampling_params = self._get_sampling_params(obj.sampling_params)
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data, input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob
-                    logprob_start_len = obj.logprob_start_len
-                    top_logprobs_num = obj.top_logprobs_num
-            else:
-                rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[input_id_index]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = (
-                        obj.text[input_id_index] if obj.text is not None else None
-                    )
-                    input_ids = obj.input_ids[input_id_index]
-
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob[index]
-                    logprob_start_len = obj.logprob_start_len[index]
-                    top_logprobs_num = obj.top_logprobs_num[index]
-
-            self._validate_input_length(input_ids)
-
-        else:  # A prefill request to cache the common prompt for parallel sampling
-            assert self.is_generation
-            if obj.text is not None:
-                if isinstance(obj.text, list):
-                    input_text = obj.text[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_text = obj.text
-                    rid = obj.rid[0]
-                if self.tokenizer is not None:
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    assert obj.input_ids is not None
-                    input_ids = obj.input_ids
-                    if isinstance(obj.input_ids, list) and isinstance(
-                        obj.input_ids[0], list
-                    ):
-                        # when obj["input_ids"] is List[List[int]]
-                        input_ids = obj.input_ids[input_id_index]
-                        rid = obj.rid[index]
-                    else:
-                        input_ids = obj.input_ids
-                        rid = obj.rid[0]
-            else:
-                input_text = None
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
+        """Tokenize one request."""
+        # Tokenize
+        input_text = obj.text
+        if obj.input_ids is None:
+            input_ids = self.tokenizer.encode(input_text)
+        else:
+            input_ids = obj.input_ids

-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
+        if self.is_generation:
             image_inputs = await self.image_processor.process_images_async(
-                obj.image_data[0], input_text or input_ids, obj
+                obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num

-        # Send to the controller
-        if self.is_generation:
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+        # Parse sampling parameters
+        sampling_params = SamplingParams(**obj.sampling_params)
+        sampling_params.normalize(self.tokenizer)
+        sampling_params.verify()
+
+        # Build return object
+        if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
@@ -287,230 +215,125 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                (
-                    obj.lora_path[input_id_index]
-                    if isinstance(obj.lora_path, list)
-                    else obj.lora_path
-                ),
+                obj.lora_path
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                rid,
-                input_text,
-                input_ids,
-                sampling_params,
-            )
-        else:
-            assert isinstance(obj, RewardReqInput)
-            tokenized_obj = TokenizedRewardReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )

-        self.send_to_scheduler.send_pyobj(tokenized_obj)
-        return rid, input_ids
+        return tokenized_obj

-    async def _handle_single_request(
+    async def _wait_one_response(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
     ):
-        rid, input_ids = await self._send_single_request(
-            obj,
-            index,
-            input_id_index=input_id_index,
-            is_cache_for_prefill=is_cache_for_prefill,
-        )
-
-        # Recv results
+        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event)
-        self.rid_to_state[rid] = state
-
-        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(state, obj, rid, request):
-                yield response
-        else:
-            assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
-            yield input_ids
-
-    async def _handle_batch_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-    ):
-        batch_size = obj.batch_size
-        if self.is_generation:
-            parallel_sample_num = obj.parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # Send prefill requests to cache the common prefix
-                parallel_sample_num += 1
-                input_id_result = [] if obj.input_ids is None else None
-                for i in range(batch_size):
-                    async for input_id in self._handle_single_request(
-                        obj,
-                        request,
-                        index=i,
-                        input_id_index=i,
-                        is_cache_for_prefill=True,
-                    ):
-                        if input_id_result is not None:
-                            input_id_result.append(input_id)
-                if input_id_result is not None:
-                    obj.input_ids = input_id_result
-        else:
-            parallel_sample_num = 1
-
-        # First send out all requests
-        generators = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
-                    index += batch_size - 1 - i
-
-                rid, _ = await self._send_single_request(
-                    obj, index, input_id_index=i, is_cache_for_prefill=False
-                )
-
-                event = asyncio.Event()
-                state = ReqState([], False, event)
-                self.rid_to_state[rid] = state
-
-                generators.append(
-                    self._wait_for_response(
-                        state,
-                        obj,
-                        rid,
-                        request,
-                        index=index,
-                        response_index=len(generators),
-                    )
-                )
-
-        # Then process the responses based on streaming option
-        is_stream = hasattr(obj, "stream") and obj.stream
+        self.rid_to_state[obj.rid] = state

-        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = [None] * len(tasks)
-
-        # Fetch results
-        while tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
-            for task in done:
-                cur_index = tasks.index(task)
-
-                try:
-                    result = task.result()
-
-                    if is_stream:
-                        yield result
-                    else:
-                        output_list[result["index"]] = result
-
-                    tasks[cur_index] = asyncio.create_task(
-                        generators[cur_index].__anext__()
-                    )
-                except StopAsyncIteration:
-                    del generators[cur_index]
-                    del tasks[cur_index]
-
-        if not is_stream:
-            yield output_list
-
-    def _validate_input_length(self, input_ids: List[int]):
-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
-
-    def _get_sampling_params(self, sampling_params_data: dict):
-        sampling_params = SamplingParams(**sampling_params_data)
-        if sampling_params.max_new_tokens != 0:
-            sampling_params.normalize(self.tokenizer)
-            sampling_params.verify()
-        return sampling_params
-
-    async def _wait_for_response(
-        self,
-        state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        response_index: int = 0,
-    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-                    for rid in [obj.rid] if obj.is_single else obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
                 continue

-            if self.is_generation:
+            if isinstance(obj, GenerateReqInput):
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob if index is None else obj.return_logprob[index],
-                    (
-                        obj.top_logprobs_num
-                        if index is None
-                        else obj.top_logprobs_num[index]
-                    ),
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
                     obj.return_text_in_logprobs,
                 )
-            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
+            else:  # isinstance(obj, (EmbeddingReqInput,))
                 out = state.out_list[-1]

-            out["index"] = response_index
-
-            # Log requests
-            if self.server_args.log_requests and state.finished:
-                logger.info(f"in={obj}, out={out}")
-
             state.out_list = []
             if state.finished:
-                del self.rid_to_state[rid]
+                if self.server_args.log_requests:
+                    # Log requests
+                    logger.info(f"in={obj}, out={out}")
+                del self.rid_to_state[obj.rid]
                 yield out
                 break

             state.event.clear()
             yield out

-    async def _wait_for_cache_prefill_response(
+    async def _handle_batch_request(
         self,
-        state: ReqState,
-        obj: GenerateReqInput,
-        rid: str,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
-        while True:
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-                break
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
+        batch_size = obj.batch_size

-        assert state.finished
-        del self.rid_to_state[rid]
+        generators = []
+        rids = []
+        if getattr(obj, "parallel_sample_num", 1) == 1:
+            # Send all requests
+            for i in range(batch_size):
+                tmp_obj = obj[i]
+                tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                generators.append(self._wait_one_response(tmp_obj, request))
+                rids.append(tmp_obj.rid)
+        else:
+            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+            # Tokenize all requests
+            objs = [obj[i] for i in range(batch_size)]
+            tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+
+            # Cache the common prefix for parallel sampling
+            for i in range(batch_size):
+                tmp_obj = copy.copy(objs[i])
+                tokenized_obj = copy.copy(tokenized_objs[i])
+                tokenized_obj.rid = tmp_obj.regenerate_rid()
+                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+                tokenized_obj.sampling_params.max_new_tokens = 0
+                tokenized_obj.stream = False
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                await self._wait_one_response(tmp_obj, request).__anext__()
+
+            # Expand requests, assign new rids for them, and send them
+            for i in range(batch_size):
+                for _ in range(obj.parallel_sample_num):
+                    tmp_obj = copy.copy(objs[i])
+                    tokenized_obj = copy.copy(tokenized_objs[i])
+                    tokenized_obj.rid = tmp_obj.regenerate_rid()
+                    self.send_to_scheduler.send_pyobj(tokenized_obj)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+
+        # Wait for all requests
+        is_stream = hasattr(obj, "stream") and obj.stream
+        if not is_stream:
+            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+            yield outputs
+        else:
+            rid_to_index = {rid: i for i, rid in enumerate(rids)}
+            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+            while task_map:
+                done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+                for task in done:
+                    gen = task_map.pop(task)
+                    try:
+                        result = task.result()
+                        result["index"] = rid_to_index[result["meta_info"]["id"]]
+                        yield result
+                        new_task = asyncio.create_task(gen.__anext__())
+                        task_map[new_task] = gen
+                    except StopAsyncIteration:
+                        pass

     def flush_cache(self):
         req = FlushCacheReq()
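When parallel_sample_num > 1, the new _handle_batch_request first sends a max_new_tokens = 0 copy of every prompt so the scheduler caches the shared prefix, then expands each prompt into parallel_sample_num requests with fresh rids and merges their streams via asyncio.wait. The fan-in idiom is generic; below is a self-contained toy version of the same task-map pattern, with dummy async generators standing in for scheduler responses (all names here are hypothetical).

    # Standalone illustration of the task-map fan-in used above: several async
    # generators are drained concurrently, and each completed task is re-armed
    # with a new __anext__() call until every generator is exhausted.
    import asyncio


    async def numbered_stream(tag: str, n: int):
        for i in range(n):
            await asyncio.sleep(0.01 * (i + 1))
            yield {"meta_info": {"id": tag}, "text": f"{tag}-chunk-{i}"}


    async def merge_streams(generators):
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                gen = task_map.pop(task)
                try:
                    result = task.result()
                    yield result
                    task_map[asyncio.create_task(gen.__anext__())] = gen
                except StopAsyncIteration:
                    pass  # this generator is finished


    async def main():
        gens = [numbered_stream(f"req-{i}", 3) for i in range(2)]
        async for item in merge_streams(gens):
            print(item["meta_info"]["id"], item["text"])


    asyncio.run(main())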
@@ -531,6 +354,25 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

+    async def get_memory_pool_size(self):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        req = GetMemPoolSizeReq()
+
+        self.send_to_scheduler.send_pyobj(req)
+        self.mem_pool_size = asyncio.Future()
+
+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
+        if self.server_args.dp_size == 1:
+            res = await self.mem_pool_size
+            return res.size
+        else:  # self.server_args.dp_size > 1
+            self.mem_pool_size_tmp = []
+            res = await self.mem_pool_size
+            ret = [r.size for r in res]
+            return ret
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
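get_memory_pool_size and the reworked update_weights below share one aggregation pattern: the caller awaits a single asyncio.Future, and handle_loop (further down) either resolves it with the first reply when dp_size == 1 or buffers replies until every data-parallel worker has answered. A toy reduction of that pattern, with no ZMQ and purely hypothetical names, looks like this:

    # Toy version of the "one future, dp_size replies" aggregation: replies are
    # buffered and the waiter's future resolves once all workers have answered.
    import asyncio


    class ReplyCollector:
        def __init__(self, dp_size: int):
            self.dp_size = dp_size
            self.future = asyncio.get_running_loop().create_future()
            self.buffer = []

        def on_reply(self, reply):
            if self.dp_size == 1:
                self.future.set_result(reply)
            else:
                self.buffer.append(reply)
                if len(self.buffer) == self.dp_size:
                    self.future.set_result(self.buffer)


    async def main():
        collector = ReplyCollector(dp_size=2)
        loop = asyncio.get_running_loop()
        loop.call_later(0.01, collector.on_reply, {"size": 1024})
        loop.call_later(0.02, collector.on_reply, {"size": 2048})
        replies = await collector.future
        print([r["size"] for r in replies])  # -> [1024, 2048]


    asyncio.run(main())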
@@ -542,25 +384,41 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format

         if not self.model_update_lock.locked():
+
             async with self.model_update_lock:
                 # wait for the previous generation requests to finish
                 while len(self.rid_to_state) > 0:
                     await asyncio.sleep(0.001)
                 self.send_to_scheduler.send_pyobj(obj)
                 self.model_update_result = asyncio.Future()
-                result = await self.model_update_result
-                if result.success:
-                    self.server_args.model_path = obj.model_path
-                    self.server_args.load_format = obj.load_format
-                    self.model_path = obj.model_path
-                return result.success, result.message
+
+                if self.server_args.dp_size == 1:
+                    result = await self.model_update_result
+                    if result.success:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    return result.success, result.message
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp = []
+                    result = await self.model_update_result
+
+                    all_success = all([r.success for r in result])
+                    if all_success is True:
+                        self.server_args.model_path = obj.model_path
+                        self.server_args.load_format = obj.load_format
+                        self.model_path = obj.model_path
+                    all_message = [r.message for r in result]
+                    all_message = " | ".join(all_message)
+                    return all_success, all_message
+
         else:
             return False, "Another update is in progress. Please try again later."

     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(3)
+            await asyncio.sleep(1)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -579,6 +437,28 @@ class TokenizerManager:
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())

+        signal_handler = SignalHandler(self)
+        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        loop.create_task(self.sigterm_watchdog())
+
+    async def sigterm_watchdog(self):
+        while not self.gracefully_exit:
+            await asyncio.sleep(60)
+
+        # drain requests
+        while True:
+            remain_num_req = len(self.rid_to_state)
+            logger.info(
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
+            )
+            if remain_num_req > 0:
+                await asyncio.sleep(5)
+            else:
+                break
+
+        kill_child_process(include_self=True)
+        sys.exit(-1)
+
     async def handle_loop(self):
         """The event loop that handles requests"""

@@ -588,7 +468,22 @@ class TokenizerManager:
             ] = await self.recv_from_detokenizer.recv_pyobj()

             if isinstance(recv_obj, UpdateWeightReqOutput):
-                self.model_update_result.set_result(recv_obj)
+                if self.server_args.dp_size == 1:
+                    self.model_update_result.set_result(recv_obj)
+                else:  # self.server_args.dp_size > 1
+                    self.model_update_tmp.append(recv_obj)
+                    # set future if the all results are recevied
+                    if len(self.model_update_tmp) == self.server_args.dp_size:
+                        self.model_update_result.set_result(self.model_update_tmp)
+                continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.mem_pool_size.set_result(recv_obj)
+                else:  # self.sever_args.dp_size > 1
+                    self.mem_pool_size_tmp.append(recv_obj)
+                    # set future if the all results are received
+                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
+                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                 continue

             assert isinstance(
@@ -607,14 +502,10 @@ class TokenizerManager:
                         "meta_info": recv_obj.meta_info[i],
                     }
                 elif isinstance(recv_obj, BatchTokenIDOut):
-                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                     out_dict = {
-                        "token_ids": recv_obj.decode_ids[
-                            read_start : recv_obj.read_offsets[i]
-                        ],
+                        "token_ids": recv_obj.output_ids[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
-
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {
@@ -666,7 +557,7 @@ class TokenizerManager:
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
             (logprob, token_id, token_text)
-            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
         ]

     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
@@ -678,3 +569,14 @@ class TokenizerManager:
                     token_top_logprobs, decode_to_text
                 )
         return top_logprobs
+
+
+class SignalHandler:
+    def __init__(self, tokenizer_manager):
+        self.tokenizer_manager = tokenizer_manager
+
+    def signal_handler(self, signum=None, frame=None):
+        logger.warning(
+            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
+        )
+        self.tokenizer_manager.gracefully_exit = True
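Together with the sigterm_watchdog added above, SignalHandler turns SIGTERM into a graceful drain: the handler only flips gracefully_exit, and the watchdog waits until rid_to_state is empty before tearing the process tree down. A minimal standalone sketch of the same flag-plus-watchdog idea follows; it uses os._exit in place of kill_child_process and simulates the signal, so every name in it is illustrative only.

    # Minimal sketch of the SIGTERM drain pattern (Unix only): the handler sets a
    # flag, and a watchdog task waits for in-flight work to finish before exiting.
    import asyncio
    import os
    import signal


    class Server:
        def __init__(self):
            self.gracefully_exit = False
            self.in_flight = {"req-1": "running"}  # stand-in for rid_to_state

        def on_sigterm(self, signum=None, frame=None):
            print(f"SIGTERM received ({signum=}), draining requests...")
            self.gracefully_exit = True

        async def sigterm_watchdog(self):
            while not self.gracefully_exit:
                await asyncio.sleep(0.1)
            while self.in_flight:  # drain remaining requests
                await asyncio.sleep(0.1)
            print("All requests drained, exiting.")
            os._exit(0)  # the real code calls kill_child_process(include_self=True)


    async def main():
        server = Server()
        loop = asyncio.get_running_loop()
        loop.add_signal_handler(signal.SIGTERM, server.on_sigterm)
        watchdog = asyncio.create_task(server.sigterm_watchdog())

        # Simulate a SIGTERM arriving, then the last request finishing.
        loop.call_later(0.1, os.kill, os.getpid(), signal.SIGTERM)
        loop.call_later(0.2, server.in_flight.clear)
        await watchdog


    asyncio.run(main())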