sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (78)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +51 -13
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/grammar.py +190 -0
  14. sglang/srt/hf_transformers_utils.py +6 -5
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  16. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  17. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  18. sglang/srt/layers/fused_moe/layer.py +28 -0
  19. sglang/srt/layers/quantization/base_config.py +16 -1
  20. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  21. sglang/srt/managers/data_parallel_controller.py +7 -6
  22. sglang/srt/managers/detokenizer_manager.py +9 -11
  23. sglang/srt/managers/image_processor.py +4 -3
  24. sglang/srt/managers/io_struct.py +70 -78
  25. sglang/srt/managers/schedule_batch.py +33 -49
  26. sglang/srt/managers/schedule_policy.py +24 -13
  27. sglang/srt/managers/scheduler.py +137 -80
  28. sglang/srt/managers/tokenizer_manager.py +224 -336
  29. sglang/srt/managers/tp_worker.py +5 -5
  30. sglang/srt/mem_cache/flush_cache.py +1 -1
  31. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  32. sglang/srt/model_executor/model_runner.py +8 -17
  33. sglang/srt/models/baichuan.py +4 -4
  34. sglang/srt/models/chatglm.py +4 -4
  35. sglang/srt/models/commandr.py +1 -1
  36. sglang/srt/models/dbrx.py +5 -5
  37. sglang/srt/models/deepseek.py +4 -4
  38. sglang/srt/models/deepseek_v2.py +4 -4
  39. sglang/srt/models/exaone.py +4 -4
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +1 -1
  42. sglang/srt/models/gpt2.py +287 -0
  43. sglang/srt/models/gpt_bigcode.py +1 -1
  44. sglang/srt/models/grok.py +4 -4
  45. sglang/srt/models/internlm2.py +4 -4
  46. sglang/srt/models/llama.py +15 -7
  47. sglang/srt/models/llama_embedding.py +2 -10
  48. sglang/srt/models/llama_reward.py +5 -0
  49. sglang/srt/models/minicpm.py +4 -4
  50. sglang/srt/models/minicpm3.py +4 -4
  51. sglang/srt/models/mixtral.py +7 -5
  52. sglang/srt/models/mixtral_quant.py +4 -4
  53. sglang/srt/models/mllama.py +5 -5
  54. sglang/srt/models/olmo.py +4 -4
  55. sglang/srt/models/olmoe.py +4 -4
  56. sglang/srt/models/qwen.py +4 -4
  57. sglang/srt/models/qwen2.py +4 -4
  58. sglang/srt/models/qwen2_moe.py +4 -4
  59. sglang/srt/models/qwen2_vl.py +4 -8
  60. sglang/srt/models/stablelm.py +4 -4
  61. sglang/srt/models/torch_native_llama.py +4 -4
  62. sglang/srt/models/xverse.py +4 -4
  63. sglang/srt/models/xverse_moe.py +4 -4
  64. sglang/srt/openai_api/adapter.py +52 -66
  65. sglang/srt/sampling/sampling_batch_info.py +7 -13
  66. sglang/srt/server.py +31 -35
  67. sglang/srt/server_args.py +34 -5
  68. sglang/srt/utils.py +40 -56
  69. sglang/test/runners.py +2 -1
  70. sglang/test/test_utils.py +73 -25
  71. sglang/utils.py +62 -1
  72. sglang/version.py +1 -1
  73. sglang-0.3.5.dist-info/METADATA +344 -0
  74. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
  75. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  76. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  77. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  78. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -16,10 +16,12 @@ limitations under the License.
  """TokenizerManager is a process that tokenizes the text."""

  import asyncio
+ import copy
  import dataclasses
- import json
  import logging
  import os
+ import signal
+ import sys
  from typing import Dict, List, Optional, Tuple, Union

  import fastapi
@@ -28,12 +30,8 @@ import zmq
  import zmq.asyncio
  from fastapi import BackgroundTasks

- from sglang.srt.hf_transformers_utils import (
- get_config,
- get_context_length,
- get_processor,
- get_tokenizer,
- )
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.managers.image_processor import (
  get_dummy_image_processor,
  get_image_processor,
@@ -49,16 +47,14 @@ from sglang.srt.managers.io_struct import (
  GetMemPoolSizeReq,
  GetMemPoolSizeReqOutput,
  ProfileReq,
- RewardReqInput,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
- TokenizedRewardReqInput,
  UpdateWeightReqInput,
  UpdateWeightReqOutput,
  )
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import is_generation_model, is_multimodal_model
+ from sglang.srt.utils import get_zmq_socket, kill_child_process

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -82,30 +78,32 @@ class TokenizerManager:
  server_args: ServerArgs,
  port_args: PortArgs,
  ):
+ # Parse args
  self.server_args = server_args

  # Init inter-process communication
  context = zmq.asyncio.Context(2)
- self.recv_from_detokenizer = context.socket(zmq.PULL)
- self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
-
- self.send_to_scheduler = context.socket(zmq.PUSH)
- self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
+ self.recv_from_detokenizer = get_zmq_socket(
+ context, zmq.PULL, port_args.tokenizer_ipc_name
+ )
+ self.send_to_scheduler = get_zmq_socket(
+ context, zmq.PUSH, port_args.scheduler_input_ipc_name
+ )

  # Read model args
  self.model_path = server_args.model_path
  self.served_model_name = server_args.served_model_name
- self.hf_config = get_config(
- self.model_path,
+ self.model_config = ModelConfig(
+ server_args.model_path,
  trust_remote_code=server_args.trust_remote_code,
- model_override_args=json.loads(server_args.json_model_override_args),
- )
- self.is_generation = is_generation_model(
- self.hf_config.architectures, self.server_args.is_embedding
- )
- self.context_len = server_args.context_length or get_context_length(
- self.hf_config
+ context_length=server_args.context_length,
+ model_override_args=server_args.json_model_override_args,
+ is_embedding=server_args.is_embedding,
  )
+
+ self.is_generation = self.model_config.is_generation
+ self.context_len = self.model_config.context_len
+
  # Create image processor placeholder
  self.image_processor = get_dummy_image_processor()

@@ -113,7 +111,7 @@ class TokenizerManager:
  if server_args.skip_tokenizer_init:
  self.tokenizer = self.processor = None
  else:
- if is_multimodal_model(self.hf_config.architectures):
+ if self.model_config.is_multimodal:
  self.processor = get_processor(
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
@@ -124,7 +122,7 @@

  # We want to parallelize the image pre-processing so we create an executor for it
  self.image_processor = get_image_processor(
- self.hf_config, server_args, self.processor
+ self.model_config.hf_config, server_args, self.processor
  )
  else:
  self.tokenizer = get_tokenizer(
@@ -141,9 +139,12 @@ class TokenizerManager:
  self.model_update_lock = asyncio.Lock()
  self.model_update_result = None

+ # Others
+ self.gracefully_exit = False
+
  async def generate_request(
  self,
- obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
  request: Optional[fastapi.Request] = None,
  ):
  if self.to_create_loop:
@@ -154,133 +155,58 @@

  if isinstance(obj, EmbeddingReqInput) and self.is_generation:
  raise ValueError(
- "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+ "This model does not appear to be an embedding model by default. "
+ "Please add `--is-embedding` when launching the server or try another model."
  )

- obj.post_init()
+ obj.normalize_batch_and_arguments()
  is_single = obj.is_single
  if is_single:
- async for response in self._handle_single_request(obj, request):
+ tokenized_obj = await self._tokenize_one_request(obj)
+ self.send_to_scheduler.send_pyobj(tokenized_obj)
+ async for response in self._wait_one_response(obj, request):
  yield response
  else:
  async for response in self._handle_batch_request(obj, request):
  yield response

- async def _send_single_request(
+ async def _tokenize_one_request(
  self,
- obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
- index: Optional[int] = None,
- input_id_index: Optional[int] = None,
- is_cache_for_prefill: Optional[bool] = False,
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
  ):
- if not is_cache_for_prefill: # The normal case with a single prompt
- if index is None:
- rid = obj.rid
- if hasattr(obj, "conv"):
- # reward model
- conv = obj.conv
- input_text = self.tokenizer.apply_chat_template(
- conv, tokenize=False
- )
- input_ids = self.tokenizer.encode(input_text)
- elif obj.input_ids is None:
- input_text = obj.text
- input_ids = self.tokenizer.encode(input_text)
- else:
- input_text = obj.text if obj.text is not None else None
- input_ids = obj.input_ids
-
- sampling_params = self._get_sampling_params(obj.sampling_params)
- if self.is_generation:
- image_inputs = await self.image_processor.process_images_async(
- obj.image_data, input_text or input_ids, obj
- )
- if image_inputs and "input_ids" in image_inputs:
- input_ids = image_inputs["input_ids"]
- return_logprob = obj.return_logprob
- logprob_start_len = obj.logprob_start_len
- top_logprobs_num = obj.top_logprobs_num
- else:
- rid = obj.rid[index]
- if hasattr(obj, "conv"):
- # reward model
- conv = obj.conv[index]
- input_text = self.tokenizer.apply_chat_template(
- conv, tokenize=False
- )
- input_ids = self.tokenizer.encode(input_text)
- elif obj.input_ids is None:
- input_text = obj.text[input_id_index]
- input_ids = self.tokenizer.encode(input_text)
- else:
- input_text = (
- obj.text[input_id_index] if obj.text is not None else None
- )
- input_ids = obj.input_ids[input_id_index]
-
- sampling_params = self._get_sampling_params(obj.sampling_params[index])
- if self.is_generation:
- image_inputs = await self.image_processor.process_images_async(
- obj.image_data[index], input_text or input_ids, obj
- )
- if image_inputs and "input_ids" in image_inputs:
- input_ids = image_inputs["input_ids"]
- return_logprob = obj.return_logprob[index]
- logprob_start_len = obj.logprob_start_len[index]
- top_logprobs_num = obj.top_logprobs_num[index]
-
- self._validate_input_length(input_ids)
-
- else: # A prefill request to cache the common prompt for parallel sampling
- assert self.is_generation
- if obj.text is not None:
- if isinstance(obj.text, list):
- input_text = obj.text[input_id_index]
- rid = obj.rid[index]
- else:
- input_text = obj.text
- rid = obj.rid[0]
- if self.tokenizer is not None:
- input_ids = self.tokenizer.encode(input_text)
- else:
- assert obj.input_ids is not None
- input_ids = obj.input_ids
- if isinstance(obj.input_ids, list) and isinstance(
- obj.input_ids[0], list
- ):
- # when obj["input_ids"] is List[List[int]]
- input_ids = obj.input_ids[input_id_index]
- rid = obj.rid[index]
- else:
- input_ids = obj.input_ids
- rid = obj.rid[0]
- else:
- input_text = None
- if isinstance(obj.input_ids, list) and isinstance(
- obj.input_ids[0], list
- ):
- # when obj["input_ids"] is List[List[int]]
- input_ids = obj.input_ids[input_id_index]
- rid = obj.rid[index]
- else:
- input_ids = obj.input_ids
- rid = obj.rid[0]
+ """Tokenize one request."""
+ # Tokenize
+ input_text = obj.text
+ if obj.input_ids is None:
+ input_ids = self.tokenizer.encode(input_text)
+ else:
+ input_ids = obj.input_ids

- sampling_params = SamplingParams(**obj.sampling_params[0])
- sampling_params.max_new_tokens = 0
+ if self.is_generation:
  image_inputs = await self.image_processor.process_images_async(
- obj.image_data[0], input_text or input_ids, obj
+ obj.image_data, input_text or input_ids, obj
  )
  if image_inputs and "input_ids" in image_inputs:
  input_ids = image_inputs["input_ids"]
- return_logprob = obj.return_logprob[0]
- logprob_start_len = obj.logprob_start_len[0]
- top_logprobs_num = obj.top_logprobs_num[0]
+ return_logprob = obj.return_logprob
+ logprob_start_len = obj.logprob_start_len
+ top_logprobs_num = obj.top_logprobs_num

- # Send to the controller
- if self.is_generation:
+ if len(input_ids) >= self.context_len:
+ raise ValueError(
+ f"The input ({len(input_ids)} tokens) is longer than the "
+ f"model's context length ({self.context_len} tokens)."
+ )
+
+ # Parse sampling parameters
+ sampling_params = SamplingParams(**obj.sampling_params)
+ sampling_params.normalize(self.tokenizer)
+ sampling_params.verify()
+
+ # Build return object
+ if isinstance(obj, GenerateReqInput):
  tokenized_obj = TokenizedGenerateReqInput(
- rid,
+ obj.rid,
  input_text,
  input_ids,
  image_inputs,
@@ -289,230 +215,125 @@ class TokenizerManager:
  logprob_start_len,
  top_logprobs_num,
  obj.stream,
- (
- obj.lora_path[input_id_index]
- if isinstance(obj.lora_path, list)
- else obj.lora_path
- ),
+ obj.lora_path
  )
  elif isinstance(obj, EmbeddingReqInput):
  tokenized_obj = TokenizedEmbeddingReqInput(
- rid,
- input_text,
- input_ids,
- sampling_params,
- )
- else:
- assert isinstance(obj, RewardReqInput)
- tokenized_obj = TokenizedRewardReqInput(
- rid,
+ obj.rid,
  input_text,
  input_ids,
  sampling_params,
  )

- self.send_to_scheduler.send_pyobj(tokenized_obj)
- return rid, input_ids
+ return tokenized_obj

- async def _handle_single_request(
+ async def _wait_one_response(
  self,
- obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
  request: Optional[fastapi.Request] = None,
- index: Optional[int] = None,
- input_id_index: Optional[int] = None,
- is_cache_for_prefill: Optional[bool] = False,
  ):
- rid, input_ids = await self._send_single_request(
- obj,
- index,
- input_id_index=input_id_index,
- is_cache_for_prefill=is_cache_for_prefill,
- )
-
- # Recv results
+ """Wait for the response of one request."""
  event = asyncio.Event()
  state = ReqState([], False, event)
- self.rid_to_state[rid] = state
-
- if not is_cache_for_prefill:
- async for response in self._wait_for_response(state, obj, rid, request):
- yield response
- else:
- assert self.is_generation
- await self._wait_for_cache_prefill_response(state, obj, rid, request)
- yield input_ids
-
- async def _handle_batch_request(
- self,
- obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
- request: Optional[fastapi.Request] = None,
- ):
- batch_size = obj.batch_size
- if self.is_generation:
- parallel_sample_num = obj.parallel_sample_num
-
- if parallel_sample_num != 1:
- # Send prefill requests to cache the common prefix
- parallel_sample_num += 1
- input_id_result = [] if obj.input_ids is None else None
- for i in range(batch_size):
- async for input_id in self._handle_single_request(
- obj,
- request,
- index=i,
- input_id_index=i,
- is_cache_for_prefill=True,
- ):
- if input_id_result is not None:
- input_id_result.append(input_id)
- if input_id_result is not None:
- obj.input_ids = input_id_result
- else:
- parallel_sample_num = 1
-
- # First send out all requests
- generators = []
- for i in range(batch_size):
- for j in range(parallel_sample_num):
- if j == 0 and parallel_sample_num != 1:
- continue
- index = i * parallel_sample_num + j
- if parallel_sample_num != 1:
- # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
- index += batch_size - 1 - i
-
- rid, _ = await self._send_single_request(
- obj, index, input_id_index=i, is_cache_for_prefill=False
- )
-
- event = asyncio.Event()
- state = ReqState([], False, event)
- self.rid_to_state[rid] = state
-
- generators.append(
- self._wait_for_response(
- state,
- obj,
- rid,
- request,
- index=index,
- response_index=len(generators),
- )
- )
-
- # Then process the responses based on streaming option
- is_stream = hasattr(obj, "stream") and obj.stream
-
- tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
- output_list = [None] * len(tasks)
+ self.rid_to_state[obj.rid] = state

- # Fetch results
- while tasks:
- done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-
- for task in done:
- cur_index = tasks.index(task)
-
- try:
- result = task.result()
-
- if is_stream:
- yield result
- else:
- output_list[result["index"]] = result
-
- tasks[cur_index] = asyncio.create_task(
- generators[cur_index].__anext__()
- )
- except StopAsyncIteration:
- del generators[cur_index]
- del tasks[cur_index]
-
- if not is_stream:
- yield output_list
-
- def _validate_input_length(self, input_ids: List[int]):
- if len(input_ids) >= self.context_len:
- raise ValueError(
- f"The input ({len(input_ids)} tokens) is longer than the "
- f"model's context length ({self.context_len} tokens)."
- )
-
- def _get_sampling_params(self, sampling_params_data: dict):
- sampling_params = SamplingParams(**sampling_params_data)
- if sampling_params.max_new_tokens != 0:
- sampling_params.normalize(self.tokenizer)
- sampling_params.verify()
- return sampling_params
-
- async def _wait_for_response(
- self,
- state: ReqState,
- obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
- rid: str,
- request: Optional[fastapi.Request] = None,
- index: Optional[int] = None,
- response_index: int = 0,
- ):
  while True:
  try:
  await asyncio.wait_for(state.event.wait(), timeout=4)
  except asyncio.TimeoutError:
  if request is not None and await request.is_disconnected():
- for rid in [obj.rid] if obj.is_single else obj.rid:
- self.abort_request(rid)
- raise ValueError(f"Abort request {rid}")
+ self.abort_request(obj.rid)
+ raise ValueError(f"Abort request {obj.rid}")
  continue

- if self.is_generation:
+ if isinstance(obj, GenerateReqInput):
  out = self.convert_logprob_style(
  state.out_list[-1],
- obj.return_logprob if index is None else obj.return_logprob[index],
- (
- obj.top_logprobs_num
- if index is None
- else obj.top_logprobs_num[index]
- ),
+ obj.return_logprob,
+ obj.top_logprobs_num,
  obj.return_text_in_logprobs,
  )
- else: # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
+ else: # isinstance(obj, (EmbeddingReqInput,))
  out = state.out_list[-1]

- out["index"] = response_index
-
- # Log requests
- if self.server_args.log_requests and state.finished:
- logger.info(f"in={obj}, out={out}")
-
  state.out_list = []
  if state.finished:
- del self.rid_to_state[rid]
+ if self.server_args.log_requests:
+ # Log requests
+ logger.info(f"in={obj}, out={out}")
+ del self.rid_to_state[obj.rid]
  yield out
  break

  state.event.clear()
  yield out

- async def _wait_for_cache_prefill_response(
+ async def _handle_batch_request(
  self,
- state: ReqState,
- obj: GenerateReqInput,
- rid: str,
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
  request: Optional[fastapi.Request] = None,
  ):
- while True:
- try:
- await asyncio.wait_for(state.event.wait(), timeout=4)
- break
- except asyncio.TimeoutError:
- if request is not None and await request.is_disconnected():
- for rid in obj.rid:
- self.abort_request(rid)
- raise ValueError(f"Abort request {rid}")
- continue
+ batch_size = obj.batch_size

- assert state.finished
- del self.rid_to_state[rid]
+ generators = []
+ rids = []
+ if getattr(obj, "parallel_sample_num", 1) == 1:
+ # Send all requests
+ for i in range(batch_size):
+ tmp_obj = obj[i]
+ tokenized_obj = await self._tokenize_one_request(tmp_obj)
+ self.send_to_scheduler.send_pyobj(tokenized_obj)
+ generators.append(self._wait_one_response(tmp_obj, request))
+ rids.append(tmp_obj.rid)
+ else:
+ # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+ # Tokenize all requests
+ objs = [obj[i] for i in range(batch_size)]
+ tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+
+ # Cache the common prefix for parallel sampling
+ for i in range(batch_size):
+ tmp_obj = copy.copy(objs[i])
+ tokenized_obj = copy.copy(tokenized_objs[i])
+ tokenized_obj.rid = tmp_obj.regenerate_rid()
+ tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+ tokenized_obj.sampling_params.max_new_tokens = 0
+ tokenized_obj.stream = False
+ self.send_to_scheduler.send_pyobj(tokenized_obj)
+ await self._wait_one_response(tmp_obj, request).__anext__()
+
+ # Expand requests, assign new rids for them, and send them
+ for i in range(batch_size):
+ for _ in range(obj.parallel_sample_num):
+ tmp_obj = copy.copy(objs[i])
+ tokenized_obj = copy.copy(tokenized_objs[i])
+ tokenized_obj.rid = tmp_obj.regenerate_rid()
+ self.send_to_scheduler.send_pyobj(tokenized_obj)
+ generators.append(self._wait_one_response(tmp_obj, request))
+ rids.append(tmp_obj.rid)
+
+ # Wait for all requests
+ is_stream = hasattr(obj, "stream") and obj.stream
+ if not is_stream:
+ outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+ yield outputs
+ else:
+ rid_to_index = {rid: i for i, rid in enumerate(rids)}
+ task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+ while task_map:
+ done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+ for task in done:
+ gen = task_map.pop(task)
+ try:
+ result = task.result()
+ result["index"] = rid_to_index[result["meta_info"]["id"]]
+ yield result
+ new_task = asyncio.create_task(gen.__anext__())
+ task_map[new_task] = gen
+ except StopAsyncIteration:
+ pass

  def flush_cache(self):
  req = FlushCacheReq()
@@ -538,9 +359,19 @@
  self.create_handle_loop()

  req = GetMemPoolSizeReq()
+
  self.send_to_scheduler.send_pyobj(req)
  self.mem_pool_size = asyncio.Future()
- return await self.mem_pool_size
+
+ # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
+ if self.server_args.dp_size == 1:
+ res = await self.mem_pool_size
+ return res.size
+ else: # self.server_args.dp_size > 1
+ self.mem_pool_size_tmp = []
+ res = await self.mem_pool_size
+ ret = [r.size for r in res]
+ return ret

  async def update_weights(
  self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -553,25 +384,41 @@
  obj.load_format = self.server_args.load_format

  if not self.model_update_lock.locked():
+
  async with self.model_update_lock:
  # wait for the previous generation requests to finish
  while len(self.rid_to_state) > 0:
  await asyncio.sleep(0.001)
  self.send_to_scheduler.send_pyobj(obj)
  self.model_update_result = asyncio.Future()
- result = await self.model_update_result
- if result.success:
- self.server_args.model_path = obj.model_path
- self.server_args.load_format = obj.load_format
- self.model_path = obj.model_path
- return result.success, result.message
+
+ if self.server_args.dp_size == 1:
+ result = await self.model_update_result
+ if result.success:
+ self.server_args.model_path = obj.model_path
+ self.server_args.load_format = obj.load_format
+ self.model_path = obj.model_path
+ return result.success, result.message
+ else: # self.server_args.dp_size > 1
+ self.model_update_tmp = []
+ result = await self.model_update_result
+
+ all_success = all([r.success for r in result])
+ if all_success is True:
+ self.server_args.model_path = obj.model_path
+ self.server_args.load_format = obj.load_format
+ self.model_path = obj.model_path
+ all_message = [r.message for r in result]
+ all_message = " | ".join(all_message)
+ return all_success, all_message
+
  else:
  return False, "Another update is in progress. Please try again later."

  def create_abort_task(self, obj: GenerateReqInput):
  # Abort the request if the client is disconnected.
  async def abort_request():
- await asyncio.sleep(3)
+ await asyncio.sleep(1)
  if obj.is_single:
  self.abort_request(obj.rid)
  else:
@@ -590,6 +437,28 @@
  loop = asyncio.get_event_loop()
  loop.create_task(self.handle_loop())

+ signal_handler = SignalHandler(self)
+ loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+ loop.create_task(self.sigterm_watchdog())
+
+ async def sigterm_watchdog(self):
+ while not self.gracefully_exit:
+ await asyncio.sleep(60)
+
+ # drain requests
+ while True:
+ remain_num_req = len(self.rid_to_state)
+ logger.info(
+ f"Gracefully exiting... remaining number of requests {remain_num_req}"
+ )
+ if remain_num_req > 0:
+ await asyncio.sleep(5)
+ else:
+ break
+
+ kill_child_process(include_self=True)
+ sys.exit(-1)
+
  async def handle_loop(self):
  """The event loop that handles requests"""

@@ -599,10 +468,22 @@
  ] = await self.recv_from_detokenizer.recv_pyobj()

  if isinstance(recv_obj, UpdateWeightReqOutput):
- self.model_update_result.set_result(recv_obj)
+ if self.server_args.dp_size == 1:
+ self.model_update_result.set_result(recv_obj)
+ else: # self.server_args.dp_size > 1
+ self.model_update_tmp.append(recv_obj)
+ # set future if the all results are recevied
+ if len(self.model_update_tmp) == self.server_args.dp_size:
+ self.model_update_result.set_result(self.model_update_tmp)
  continue
  elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
- self.mem_pool_size.set_result(recv_obj)
+ if self.server_args.dp_size == 1:
+ self.mem_pool_size.set_result(recv_obj)
+ else: # self.sever_args.dp_size > 1
+ self.mem_pool_size_tmp.append(recv_obj)
+ # set future if the all results are received
+ if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
+ self.mem_pool_size.set_result(self.mem_pool_size_tmp)
  continue

  assert isinstance(
@@ -621,14 +502,10 @@
  "meta_info": recv_obj.meta_info[i],
  }
  elif isinstance(recv_obj, BatchTokenIDOut):
- read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
  out_dict = {
- "token_ids": recv_obj.decode_ids[
- read_start : recv_obj.read_offsets[i]
- ],
+ "token_ids": recv_obj.output_ids[i],
  "meta_info": recv_obj.meta_info[i],
  }
-
  else:
  assert isinstance(recv_obj, BatchEmbeddingOut)
  out_dict = {
@@ -680,7 +557,7 @@
  token_texts = self.tokenizer.batch_decode(token_ids)
  return [
  (logprob, token_id, token_text)
- for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+ for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
  ]

  def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
@@ -692,3 +569,14 @@
  token_top_logprobs, decode_to_text
  )
  return top_logprobs
+
+
+ class SignalHandler:
+ def __init__(self, tokenizer_manager):
+ self.tokenizer_manager = tokenizer_manager
+
+ def signal_handler(self, signum=None, frame=None):
+ logger.warning(
+ f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
+ )
+ self.tokenizer_manager.gracefully_exit = True
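
Note: the constructor now obtains both IPC sockets through get_zmq_socket from sglang.srt.utils instead of calling context.socket(...) and bind/connect inline. The helper's body is not part of this diff; the sketch below only illustrates what a helper with the call signature seen above (context, socket type, ipc name) might look like, and is not the actual sglang implementation.

import zmq
import zmq.asyncio


def get_zmq_socket(context: zmq.asyncio.Context, socket_type: int, endpoint: str):
    # Hypothetical sketch, not the sglang.srt.utils implementation.
    # The tokenizer manager pulls results from the detokenizer (bind)
    # and pushes tokenized requests to the scheduler (connect).
    socket = context.socket(socket_type)
    if socket_type == zmq.PULL:
        socket.bind(f"ipc://{endpoint}")
    elif socket_type == zmq.PUSH:
        socket.connect(f"ipc://{endpoint}")
    else:
        raise ValueError(f"Unsupported socket type: {socket_type}")
    return socket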
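
Note: the new SignalHandler / sigterm_watchdog pair follows a common asyncio shutdown pattern: the signal handler only flips gracefully_exit, and a background task drains the in-flight requests before the process exits. A minimal standalone sketch of the same pattern (all names and timings here are illustrative, not sglang's API):

import asyncio
import signal


class ToyManager:
    def __init__(self):
        self.inflight = {}  # stand-in for rid_to_state
        self.gracefully_exit = False

    async def watchdog(self):
        # Wait for the signal handler to flip the flag.
        while not self.gracefully_exit:
            await asyncio.sleep(0.1)
        # Drain: wait until every in-flight request has finished.
        while self.inflight:
            await asyncio.sleep(0.1)
        print("all requests drained, exiting")

    async def main(self):
        loop = asyncio.get_running_loop()
        # The handler itself does no work; it only records the request to exit.
        loop.add_signal_handler(
            signal.SIGTERM, lambda: setattr(self, "gracefully_exit", True)
        )
        await self.watchdog()


if __name__ == "__main__":
    asyncio.run(ToyManager().main())  # send SIGTERM (kill -TERM <pid>) to trigger shutdown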
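
Note: with --dp-size > 1, handle_loop now buffers one UpdateWeightReqOutput or GetMemPoolSizeReqOutput per data-parallel rank and resolves the shared future only after all replies have arrived. The fan-in pattern in isolation (toy names, not sglang classes):

import asyncio


async def fan_in_demo(dp_size: int = 4):
    loop = asyncio.get_running_loop()
    result_future = loop.create_future()
    partial_results = []

    def on_reply(reply):
        # Called once per data-parallel rank; resolve only when all are in.
        partial_results.append(reply)
        if len(partial_results) == dp_size:
            result_future.set_result(partial_results)

    # Simulate staggered replies from dp_size ranks.
    for rank in range(dp_size):
        loop.call_later(0.01 * (rank + 1), on_reply, rank)

    print(await result_future)  # [0, 1, 2, 3]


asyncio.run(fan_in_demo())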