sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py ADDED
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+import dataclasses
+from typing import TYPE_CHECKING, List
+
+import torch
+
+import sglang.srt.sampling.penaltylib as penaltylib
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+@dataclasses.dataclass
+class SamplingBatchInfo:
+    # Basic Info
+    vocab_size: int
+
+    # Batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    min_ps: torch.Tensor = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
+    # Bias Tensors
+    logit_bias: torch.Tensor = None
+    vocab_mask: torch.Tensor = None
+
+    # Penalizer
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    linear_penalties: torch.Tensor = None
+    scaling_penalties: torch.Tensor = None
+
+    def has_bias(self):
+        return (
+            self.logit_bias is not None
+            or self.vocab_mask is not None
+            or self.linear_penalties is not None
+            or self.scaling_penalties is not None
+        )
+
+    @classmethod
+    def dummy_one(cls, max_bs: int, vocab_size: int):
+        ret = cls(vocab_size=vocab_size)
+        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
+        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
+        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+        return ret
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            # NOTE: We do not use cuda graph when there is bias tensors
+            assert not self.has_bias()
+            return SamplingBatchInfo(
+                vocab_size=self.vocab_size,
+                temperatures=self.temperatures[key],
+                top_ps=self.top_ps[key],
+                top_ks=self.top_ks[key],
+                min_ps=self.min_ps[key],
+                need_min_p_sampling=self.need_min_p_sampling,
+            )
+        else:
+            raise NotImplementedError
+
+    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
+        # NOTE: We do not use cuda graph when there is bias tensors
+        assert not self.has_bias()
+
+        self.vocab_size = other.vocab_size
+        self.need_min_p_sampling = other.need_min_p_sampling
+
+        self.temperatures[:bs] = other.temperatures
+        self.top_ps[:bs] = other.top_ps
+        self.top_ks[:bs] = other.top_ks
+        self.min_ps[:bs] = other.min_ps
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        device = "cuda"
+        reqs = batch.reqs
+        ret = cls(vocab_size=vocab_size)
+
+        ret.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        ret.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        ret.min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
+        )
+
+        # Handle logit bias but only allocate when needed
+        ret.logit_bias = None
+
+        ret.update_regex_vocab_mask(batch)
+
+        return ret
+
+    def prepare_penalties(self):
+        self.scaling_penalties = None
+        self.linear_penalties = None
+
+        for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
+                if penalizer.is_prepared():
+                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+            else:
+                if penalizer.is_prepared():
+                    if self.linear_penalties is None:
+                        bs = self.penalizer_orchestrator.batch.batch_size()
+                        self.linear_penalties = torch.zeros(
+                            (bs, self.vocab_size),
+                            dtype=torch.float32,
+                            device="cuda",
+                        )
+                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+
+    def update_regex_vocab_mask(self, batch: ScheduleBatch):
+        bs, reqs = batch.batch_size(), batch.reqs
+        device = "cuda"
+        has_regex = any(req.regex_fsm is not None for req in reqs)
+
+        # Reset the vocab mask
+        self.vocab_mask = None
+
+        if has_regex:
+            for i, req in enumerate(reqs):
+                if req.regex_fsm is not None:
+                    if self.vocab_mask is None:
+                        self.vocab_mask = torch.zeros(
+                            bs, self.vocab_size, dtype=torch.bool, device=device
+                        )
+                    self.vocab_mask[i][
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                    ] = 1
+
+    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+            "logit_bias",
+        ]:
+            self_val = getattr(self, item, None)
+            if self_val is not None:  # logit_bias can be None
+                setattr(self, item, self_val[new_indices])
+
+    def merge(self, other: "SamplingBatchInfo"):
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
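
The new SamplingBatchInfo packs each request's sampling parameters into batched CUDA tensors so they can be applied to the whole batch in one vectorized pass (the actual sampler lives in the new sglang/srt/layers/sampler.py listed above). The sketch below is an illustration only, not sglang's implementation: it shows how batched temperature, penalty, and min_p tensors of the shapes used above could be applied to a [batch_size, vocab_size] logits matrix.

from typing import Optional

import torch


def apply_batched_sampling_params(
    logits: torch.Tensor,                       # [batch_size, vocab_size]
    temperatures: torch.Tensor,                 # [batch_size, 1]
    scaling_penalties: Optional[torch.Tensor],  # repetition penalties, broadcastable to logits
    linear_penalties: Optional[torch.Tensor],   # frequency/presence penalties, [batch_size, vocab_size]
    min_ps: torch.Tensor,                       # [batch_size]
) -> torch.Tensor:
    """Toy, eager-mode application of batched sampling params; returns probabilities."""
    if scaling_penalties is not None:
        # Repetition penalty: shrink positive logits, amplify negative ones.
        logits = torch.where(
            logits > 0, logits / scaling_penalties, logits * scaling_penalties
        )
    if linear_penalties is not None:
        # Frequency/presence penalties are additive.
        logits = logits + linear_penalties

    # Temperature scaling, then convert to probabilities.
    probs = torch.softmax(logits / temperatures, dim=-1)

    # min_p filtering: zero out tokens whose probability falls below
    # min_p * (probability of the most likely token), then renormalize.
    max_probs = probs.max(dim=-1, keepdim=True).values
    keep = probs >= min_ps.view(-1, 1) * max_probs
    probs = torch.where(keep, probs, torch.zeros_like(probs))
    return probs / probs.sum(dim=-1, keepdim=True)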
sglang/srt/{sampling_params.py → sampling/sampling_params.py} RENAMED
@@ -30,19 +30,20 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        dtype: Optional[str] = None,
         regex: Optional[str] = None,
         n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
@@ -53,7 +54,6 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.dtype = dtype
         self.regex = regex
         self.n = n
 
@@ -63,8 +63,6 @@ class SamplingParams:
             self.top_k = 1
         if self.top_k == -1:
             self.top_k = 1 << 30  # whole vocabulary
-        if self.dtype == "int":
-            self.stop_strs = [" ", "\n"]
 
     def verify(self):
         if self.temperature < 0.0:
@@ -73,6 +71,8 @@ class SamplingParams:
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
                 f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
@@ -127,3 +127,17 @@ class SamplingParams:
                 else:
                     stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
+
+    def to_srt_kwargs(self):
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "stop": self.stop_strs,
+            "stop_token_ids": list(self.stop_token_ids),
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "ignore_eos": self.ignore_eos,
+            "regex": self.regex,
+        }
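
The renamed sampling_params.py adds a min_p knob (validated to [0, 1] in verify()) alongside temperature/top_p/top_k and drops the old dtype special case. Over HTTP, these fields travel in the sampling_params dict of a /generate request, as in the warmup and /health_generate code in server.py below. A hedged client-side example; the address assumes sglang's default port 30000:

import requests

# Illustrative request only; field names follow the sampling_params shown in this diff.
payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "max_new_tokens": 16,
        "temperature": 0.7,
        "top_p": 0.9,
        "min_p": 0.05,  # new in 0.2.14: drop tokens below 5% of the top token's probability
    },
}
resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
print(resp.json())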
sglang/srt/server.py CHANGED
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
-import psutil
 import requests
 import uvicorn
 import uvloop
@@ -52,7 +50,11 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -72,6 +74,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -92,10 +95,25 @@ tokenizer_manager = None
 
 @app.get("/health")
 async def health() -> Response:
-    """Health check."""
+    """Check the health of the http server."""
     return Response(status_code=200)
 
 
+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
+    gri = GenerateReqInput(
+        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+    )
+    try:
+        async for _ in tokenizer_manager.generate_request(gri, request):
+            break
+        return Response(status_code=200)
+    except Exception as e:
+        logger.exception(e)
+        return Response(status_code=503)
+
+
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
@@ -120,6 +138,23 @@ async def flush_cache():
     )
 
 
+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -236,15 +271,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager
 
-    logging.basicConfig(
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)
 
     server_args.check_server_args()
     _set_envs_and_config(server_args)
 
-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
@@ -264,30 +296,34 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
 
     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
-            ]
-            tp_rank_range = list(
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
-            )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
             )
-        while True:
-            pass
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
@@ -295,11 +331,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
+
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -317,15 +355,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
-        )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
         )
-        sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     # Add api key authorization
@@ -334,12 +368,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
     t.start()
 
-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -358,6 +392,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
     # Set ulimit
     set_ulimit()
@@ -375,23 +410,18 @@ def _set_envs_and_config(server_args: ServerArgs):
     # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
     maybe_set_triton_cache_manager()
 
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
     # Check flashinfer version
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.4",
+            "0.1.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -414,8 +444,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -440,21 +471,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
-
-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -492,6 +515,7 @@ class Runtime:
 
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, model_overide_args, pipe_writer),
@@ -533,11 +557,18 @@ class Runtime:
         prompt: str,
         sampling_params: Optional[Dict] = None,
     ):
-        json_data = {
-            "text": prompt,
-            "sampling_params": sampling_params,
-            "stream": True,
-        }
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
         pos = 0
 
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
@@ -549,24 +580,29 @@ class Runtime:
                     if chunk == "data: [DONE]\n\n":
                         break
                     data = json.loads(chunk[5:].strip("\n"))
-                    cur = data["text"][pos:]
-                    if cur:
-                        yield cur
-                        pos += len(cur)
+                    if hasattr(data, "text"):
+                        cur = data["text"][pos:]
+                        if cur:
+                            yield cur
+                            pos += len(cur)
+                    else:
+                        yield data
 
     add_request = async_generate
 
     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
     ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
         }
         response = requests.post(
@@ -577,7 +613,7 @@ class Runtime:
 
     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,