sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -1,7 +1,7 @@
  from __future__ import annotations
 
  import dataclasses
- from typing import TYPE_CHECKING, List
+ from typing import TYPE_CHECKING, List, Optional
 
  import torch
 
@@ -20,6 +20,9 @@ class SamplingBatchInfo:
  top_ks: torch.Tensor
  min_ps: torch.Tensor
 
+ # All requests use greedy sampling
+ is_all_greedy: bool
+
  # Dispatch in CUDA graph
  need_min_p_sampling: bool
 
@@ -33,27 +36,39 @@ class SamplingBatchInfo:
  regex_fsm_states: List[int] = None
 
  # Penalizer
- penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
- linear_penalties: torch.Tensor = None
- scaling_penalties: torch.Tensor = None
+ penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+ linear_penalties: Optional[torch.Tensor] = None
+ scaling_penalties: Optional[torch.Tensor] = None
+
+ # Device
+ device: str = "cuda"
 
  @classmethod
- def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+ def from_schedule_batch(
+ cls,
+ batch: ScheduleBatch,
+ vocab_size: int,
+ disable_penalizer: bool,
+ ):
  reqs = batch.reqs
- with batch.input_ids.device:
- temperatures = torch.tensor(
+ device = batch.input_ids.device
+ temperatures = (
+ torch.tensor(
  [r.sampling_params.temperature for r in reqs],
  dtype=torch.float,
- ).view(-1, 1)
- top_ps = torch.tensor(
- [r.sampling_params.top_p for r in reqs], dtype=torch.float
- )
- top_ks = torch.tensor(
- [r.sampling_params.top_k for r in reqs], dtype=torch.int
- )
- min_ps = torch.tensor(
- [r.sampling_params.min_p for r in reqs], dtype=torch.float
  )
+ .view(-1, 1)
+ .to(device, non_blocking=True)
+ )
+ top_ps = torch.tensor(
+ [r.sampling_params.top_p for r in reqs], dtype=torch.float
+ ).to(device, non_blocking=True)
+ top_ks = torch.tensor(
+ [r.sampling_params.top_k for r in reqs], dtype=torch.int32
+ ).to(device, non_blocking=True)
+ min_ps = torch.tensor(
+ [r.sampling_params.min_p for r in reqs], dtype=torch.float
+ ).to(device, non_blocking=True)
 
  ret = cls(
  temperatures=temperatures,
@@ -61,7 +76,9 @@ class SamplingBatchInfo:
  top_ks=top_ks,
  min_ps=min_ps,
  need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+ is_all_greedy=top_ks.max().item() <= 1,
  vocab_size=vocab_size,
+ device=batch.input_ids.device,
  )
  # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
 
@@ -71,18 +88,21 @@ class SamplingBatchInfo:
  #
  # While we choose not to even create the class instances if they are not required, this
  # could add additional complexity to the {ScheduleBatch} class, especially we need to
- # handle {filter_batch()} and {merge()} cases as well.
- ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
- vocab_size=vocab_size,
- batch=batch,
- device="cuda",
- Penalizers={
- penaltylib.BatchedFrequencyPenalizer,
- penaltylib.BatchedMinNewTokensPenalizer,
- penaltylib.BatchedPresencePenalizer,
- penaltylib.BatchedRepetitionPenalizer,
- },
- )
+ # handle {filter_batch()} and {merge_batch()} cases as well.
+ if disable_penalizer:
+ ret.penalizer_orchestrator = None
+ else:
+ ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+ vocab_size=vocab_size,
+ batch=batch,
+ device=batch.input_ids.device,
+ Penalizers={
+ penaltylib.BatchedFrequencyPenalizer,
+ penaltylib.BatchedMinNewTokensPenalizer,
+ penaltylib.BatchedPresencePenalizer,
+ penaltylib.BatchedRepetitionPenalizer,
+ },
+ )
 
  # Handle logit bias but only allocate when needed
  ret.logit_bias = None
@@ -93,43 +113,50 @@ class SamplingBatchInfo:
  return len(self.temperatures)
 
  def update_penalties(self):
+ if not self.penalizer_orchestrator:
+ return
+
  self.scaling_penalties = None
  self.linear_penalties = None
 
  for penalizer in self.penalizer_orchestrator.penalizers.values():
+ if not penalizer.is_prepared():
+ continue
+
  if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
- if penalizer.is_prepared():
- self.scaling_penalties = penalizer.cumulated_repetition_penalties
+ self.scaling_penalties = penalizer.cumulated_repetition_penalties
  else:
- if penalizer.is_prepared():
- if self.linear_penalties is None:
- bs = self.penalizer_orchestrator.batch.batch_size()
- self.linear_penalties = torch.zeros(
- (bs, self.vocab_size),
- dtype=torch.float32,
- device="cuda",
- )
- self.linear_penalties = penalizer.apply(self.linear_penalties)
+ if self.linear_penalties is None:
+ bs = self.penalizer_orchestrator.batch.batch_size()
+ self.linear_penalties = torch.zeros(
+ (bs, self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ self.linear_penalties = penalizer.apply(self.linear_penalties)
 
  def update_regex_vocab_mask(self):
  has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-
- # Reset the vocab mask
- self.vocab_mask = None
-
- if has_regex:
- self.vocab_mask = torch.zeros(
- len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
- )
- for i, regex_fsm in enumerate(self.regex_fsms):
- if regex_fsm is not None:
- self.vocab_mask[i].fill_(1)
- self.vocab_mask[i][
- regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
- ] = 0
+ if not has_regex:
+ self.vocab_mask = None
+ return
+
+ self.vocab_mask = torch.zeros(
+ len(self.temperatures),
+ self.vocab_size,
+ dtype=torch.bool,
+ device=self.device,
+ )
+ for i, regex_fsm in enumerate(self.regex_fsms):
+ if regex_fsm is not None:
+ self.vocab_mask[i].fill_(1)
+ self.vocab_mask[i][
+ regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+ ] = 0
 
  def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
- self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+ if self.penalizer_orchestrator:
+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
  for item in [
  "temperatures",
@@ -144,7 +171,12 @@ class SamplingBatchInfo:
 
  @staticmethod
  def merge_bias_tensor(
- lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+ lhs: torch.Tensor,
+ rhs: torch.Tensor,
+ bs1: int,
+ bs2: int,
+ device: str,
+ default: int = 0,
  ):
  # bias tensor can be None
  if lhs is not None or rhs is not None:
@@ -155,15 +187,16 @@ class SamplingBatchInfo:
  shape, dtype = rhs.shape[1:], rhs.dtype
  with torch.dtype(dtype):
  if lhs is None:
- lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+ lhs = torch.empty((bs1, *shape), device=device).fill_(default)
  if rhs is None:
- rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+ rhs = torch.empty((bs2, *shape), device=device).fill_(default)
  return torch.cat([lhs, rhs])
 
  return None
 
  def merge_batch(self, other: "SamplingBatchInfo"):
- self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+ if self.penalizer_orchestrator:
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
  for item in [
  "temperatures",
@@ -175,6 +208,19 @@ class SamplingBatchInfo:
  other_val = getattr(other, item, None)
  setattr(self, item, torch.concat([self_val, other_val]))
 
+ self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
  self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
- self.logit_bias, other.logit_bias, len(self), len(other)
+ self.logit_bias, other.logit_bias, len(self), len(other), self.device
+ )
+
+ def copy(self):
+ return SamplingBatchInfo(
+ temperatures=self.temperatures,
+ top_ps=self.top_ps,
+ top_ks=self.top_ks,
+ min_ps=self.min_ps,
+ is_all_greedy=self.is_all_greedy,
+ need_min_p_sampling=self.need_min_p_sampling,
+ vocab_size=self.vocab_size,
+ device=self.device,
  )
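For context on the new `is_all_greedy` field: it is computed once per batch as `top_ks.max().item() <= 1` and AND-ed together in `merge_batch`, presumably so the sampler can take a cheap argmax path when every request decodes greedily. A minimal, standalone sketch of the same check using only `torch`; the batch values and vocabulary size below are made up for illustration:

```python
import torch

# Hypothetical per-request top_k values; in sglang, top_k == 1 means greedy decoding.
top_ks = torch.tensor([1, 1, 1, 1], dtype=torch.int32)

# Mirrors the expression added in from_schedule_batch().
is_all_greedy = top_ks.max().item() <= 1

logits = torch.randn(len(top_ks), 32000)  # fake logits, vocab_size = 32000

if is_all_greedy:
    # Whole-batch greedy: a single argmax replaces top-k/top-p sampling.
    next_tokens = torch.argmax(logits, dim=-1)
else:
    # Otherwise the regular sampling path (top-k/top-p/min-p) would run.
    probs = torch.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

print(is_all_greedy, next_tokens.shape)
```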
sglang/srt/sampling/sampling_params.py CHANGED
@@ -40,6 +40,7 @@ class SamplingParams:
  regex: Optional[str] = None,
  n: int = 1,
  json_schema: Optional[str] = None,
+ no_stop_trim: bool = False,
  ) -> None:
  self.temperature = temperature
  self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
  self.regex = regex
  self.n = n
  self.json_schema = json_schema
+ self.no_stop_trim = no_stop_trim
 
  # Process some special cases
  if self.temperature < _SAMPLING_EPS:
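The only change to `SamplingParams` is the new `no_stop_trim` flag, which, judging by the name, keeps the matched stop sequence in the returned text instead of trimming it. A hedged sketch of passing it through the native `/generate` HTTP API; it assumes a server is already running on the default local port, and the exact trimming behavior is not shown in this diff:

```python
import requests

# Assumes an sglang server launched locally, e.g.:
#   python -m sglang.launch_server --model-path <model> --port 30000
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List three prime numbers:",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 64,
            "stop": ["\n\n"],
            # New in 0.3.4: do not trim the matched stop sequence from the output.
            "no_stop_trim": True,
        },
    },
)
print(resp.json()["text"])
```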
sglang/srt/server.py CHANGED
@@ -25,11 +25,12 @@ import json
  import logging
  import multiprocessing as mp
  import os
- import random
  import threading
  import time
  from http import HTTPStatus
- from typing import Dict, List, Optional, Union
+ from typing import AsyncIterator, Dict, List, Optional, Union
+
+ import orjson
 
  # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -40,10 +41,14 @@ import uvicorn
  import uvloop
  from fastapi import FastAPI, File, Form, Request, UploadFile
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, Response, StreamingResponse
+ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+ from uvicorn.config import LOGGING_CONFIG
 
  from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.managers.data_parallel_controller import (
+ run_data_parallel_controller_process,
+ )
  from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
  from sglang.srt.managers.io_struct import (
  EmbeddingReqInput,
@@ -145,18 +150,40 @@ async def flush_cache():
  )
 
 
+ @app.get("/start_profile")
+ @app.post("/start_profile")
+ async def start_profile():
+ """Start profiling."""
+ tokenizer_manager.start_profile()
+ return Response(
+ content="Start profiling.\n",
+ status_code=200,
+ )
+
+
+ @app.get("/stop_profile")
+ @app.post("/stop_profile")
+ async def stop_profile():
+ """Stop profiling."""
+ tokenizer_manager.stop_profile()
+ return Response(
+ content="Stop profiling. This will take some time.\n",
+ status_code=200,
+ )
+
+
  @app.post("/update_weights")
  async def update_weights(obj: UpdateWeightReqInput, request: Request):
  """Update the weights inplace without re-launching the server."""
  success, message = await tokenizer_manager.update_weights(obj, request)
  content = {"success": success, "message": message}
  if success:
- return JSONResponse(
+ return ORJSONResponse(
  content,
  status_code=HTTPStatus.OK,
  )
  else:
- return JSONResponse(
+ return ORJSONResponse(
  content,
  status_code=HTTPStatus.BAD_REQUEST,
  )
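The new `/start_profile` and `/stop_profile` routes are thin wrappers around `tokenizer_manager.start_profile()` and `stop_profile()`. A client-side sketch of driving them (assumes a local server; how and where the profiler writes its traces is configured server-side and is not part of this hunk):

```python
import requests

BASE = "http://localhost:30000"  # assumed local sglang server

# Start collecting a profile on the server side.
print(requests.post(f"{BASE}/start_profile").text)

# ... issue some /generate traffic here so there is something to profile ...

# Finalize the profile; the response notes that this can take a while.
print(requests.post(f"{BASE}/stop_profile").text)
```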
@@ -167,14 +194,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
  """Handle a generate request."""
  if obj.stream:
 
- async def stream_results():
+ async def stream_results() -> AsyncIterator[bytes]:
  try:
  async for out in tokenizer_manager.generate_request(obj, request):
- yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+ yield b"data: " + orjson.dumps(
+ out, option=orjson.OPT_NON_STR_KEYS
+ ) + b"\n\n"
  except ValueError as e:
  out = {"error": {"message": str(e)}}
- yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
- yield "data: [DONE]\n\n"
+ yield b"data: " + orjson.dumps(
+ out, option=orjson.OPT_NON_STR_KEYS
+ ) + b"\n\n"
+ yield b"data: [DONE]\n\n"
 
  return StreamingResponse(
  stream_results(),
@@ -186,7 +217,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
  ret = await tokenizer_manager.generate_request(obj, request).__anext__()
  return ret
  except ValueError as e:
- return JSONResponse(
+ return ORJSONResponse(
  {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
  )
 
@@ -201,7 +232,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
  ret = await tokenizer_manager.generate_request(obj, request).__anext__()
  return ret
  except ValueError as e:
- return JSONResponse(
+ return ORJSONResponse(
  {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
  )
 
@@ -216,7 +247,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
  ret = await tokenizer_manager.generate_request(obj, request).__anext__()
  return ret
  except ValueError as e:
- return JSONResponse(
+ return ORJSONResponse(
  {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
  )
 
@@ -235,13 +266,13 @@ async def openai_v1_chat_completions(raw_request: Request):
  return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
- @app.post("/v1/embeddings")
+ @app.post("/v1/embeddings", response_class=ORJSONResponse)
  async def openai_v1_embeddings(raw_request: Request):
  response = await v1_embeddings(tokenizer_manager, raw_request)
  return response
 
 
- @app.get("/v1/models")
+ @app.get("/v1/models", response_class=ORJSONResponse)
  def available_models():
  """Show available models."""
  served_model_names = [tokenizer_manager.served_model_name]
@@ -315,30 +346,40 @@ def launch_engine(
  server_args.model_path, server_args.tokenizer_path
  )
 
- # Launch tensor parallel scheduler processes
- scheduler_procs = []
- scheduler_pipe_readers = []
- tp_size_per_node = server_args.tp_size // server_args.nnodes
- tp_rank_range = range(
- tp_size_per_node * server_args.node_rank,
- tp_size_per_node * (server_args.node_rank + 1),
- )
- for tp_rank in tp_rank_range:
+ if server_args.dp_size == 1:
+ # Launch tensor parallel scheduler processes
+ scheduler_procs = []
+ scheduler_pipe_readers = []
+ tp_size_per_node = server_args.tp_size // server_args.nnodes
+ tp_rank_range = range(
+ tp_size_per_node * server_args.node_rank,
+ tp_size_per_node * (server_args.node_rank + 1),
+ )
+ for tp_rank in tp_rank_range:
+ reader, writer = mp.Pipe(duplex=False)
+ gpu_id = tp_rank % tp_size_per_node
+ proc = mp.Process(
+ target=run_scheduler_process,
+ args=(server_args, port_args, gpu_id, tp_rank, None, writer),
+ )
+ proc.start()
+ scheduler_procs.append(proc)
+ scheduler_pipe_readers.append(reader)
+
+ if server_args.node_rank >= 1:
+ # For other nodes, they do not need to run tokenizer or detokenizer,
+ # so they can just wait here.
+ while True:
+ pass
+ else:
+ # Launch the data parallel controller
  reader, writer = mp.Pipe(duplex=False)
- gpu_id = tp_rank % tp_size_per_node
+ scheduler_pipe_readers = [reader]
  proc = mp.Process(
- target=run_scheduler_process,
- args=(server_args, port_args, gpu_id, tp_rank, writer),
+ target=run_data_parallel_controller_process,
+ args=(server_args, port_args, writer),
  )
  proc.start()
- scheduler_procs.append(proc)
- scheduler_pipe_readers.append(reader)
-
- if server_args.node_rank >= 1:
- # For other nodes, they do not need to run tokenizer or detokenizer,
- # so they can just wait here.
- while True:
- pass
 
  # Launch detokenizer process
  detoken_proc = mp.Process(
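With the branch above, setting `dp_size > 1` makes `launch_engine` spawn a single data-parallel controller process (which presumably owns the per-replica schedulers) instead of launching tensor-parallel schedulers directly. A hedged sketch of exercising that path programmatically; the `ServerArgs` field names are taken from this diff, the model path is a placeholder, and the equivalent CLI flag is assumed to be `--dp-size`:

```python
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    # dp_size=2 routes launch_engine() into the new
    # run_data_parallel_controller_process branch shown above.
    # Roughly equivalent CLI (assumed):
    #   python -m sglang.launch_server --model-path <model> --dp-size 2
    server_args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        tp_size=1,
        dp_size=2,
    )
    launch_server(server_args)
```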
@@ -394,6 +435,14 @@ def launch_server(
 
  try:
  # Listen for HTTP requests
+ LOGGING_CONFIG["formatters"]["default"][
+ "fmt"
+ ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+ LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+ LOGGING_CONFIG["formatters"]["access"][
+ "fmt"
+ ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+ LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
  uvicorn.run(
  app,
  host=server_args.host,
@@ -412,7 +461,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  os.environ["NCCL_CUMEM_ENABLE"] = "0"
  os.environ["NCCL_NVLS_ENABLE"] = "0"
  os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
- os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
  # Set ulimit
  set_ulimit()
@@ -493,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
  kill_child_process(pid, including_parent=False)
  return
 
+ # logger.info(f"{res.json()=}")
+
  logger.info("The server is fired up and ready to roll!")
  if pipe_finish_writer is not None:
  pipe_finish_writer.send("ready")
@@ -657,6 +708,10 @@ class Runtime:
  self.shutdown()
 
 
+ STREAM_END_SYMBOL = b"data: [DONE]"
+ STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
  class Engine:
  """
  SRT Engine without an HTTP server layer.
@@ -681,7 +736,10 @@ class Engine:
  logprob_start_len: Optional[Union[List[int], int]] = None,
  top_logprobs_num: Optional[Union[List[int], int]] = None,
  lora_path: Optional[List[Optional[str]]] = None,
+ stream: bool = False,
  ):
+ # TODO (ByronHsu): refactor to reduce the duplicated code
+
  obj = GenerateReqInput(
  text=prompt,
  sampling_params=sampling_params,
@@ -689,13 +747,89 @@ class Engine:
  logprob_start_len=logprob_start_len,
  top_logprobs_num=top_logprobs_num,
  lora_path=lora_path,
+ stream=stream,
  )
 
  # get the current event loop
  loop = asyncio.get_event_loop()
- return loop.run_until_complete(generate_request(obj, None))
+ ret = loop.run_until_complete(generate_request(obj, None))
+
+ if stream is True:
+
+ def generator_wrapper():
+ offset = 0
+ loop = asyncio.get_event_loop()
+ generator = ret.body_iterator
+ while True:
+ chunk = loop.run_until_complete(generator.__anext__())
+
+ if chunk.startswith(STREAM_END_SYMBOL):
+ break
+ else:
+ data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+ data["text"] = data["text"][offset:]
+ offset += len(data["text"])
+ yield data
+
+ # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+ # however, it allows to wrap the generator as a subfunction and return
+ return generator_wrapper()
+ else:
+ return ret
+
+ async def async_generate(
+ self,
+ prompt: Union[str, List[str]],
+ sampling_params: Optional[Dict] = None,
+ return_logprob: Optional[Union[List[bool], bool]] = False,
+ logprob_start_len: Optional[Union[List[int], int]] = None,
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
+ lora_path: Optional[List[Optional[str]]] = None,
+ stream: bool = False,
+ ):
+ obj = GenerateReqInput(
+ text=prompt,
+ sampling_params=sampling_params,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ lora_path=lora_path,
+ stream=stream,
+ )
+
+ ret = await generate_request(obj, None)
+
+ if stream is True:
+ generator = ret.body_iterator
+
+ async def generator_wrapper():
+
+ offset = 0
+
+ while True:
+ chunk = await generator.__anext__()
+
+ if chunk.startswith(STREAM_END_SYMBOL):
+ break
+ else:
+ data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+ data["text"] = data["text"][offset:]
+ offset += len(data["text"])
+ yield data
+
+ return generator_wrapper()
+ else:
+ return ret
 
  def shutdown(self):
  kill_child_process(os.getpid(), including_parent=False)
 
- # TODO (ByronHsu): encode and async generate
+ def get_tokenizer(self):
+ global tokenizer_manager
+
+ if tokenizer_manager is None:
+ raise ReferenceError("Tokenizer Manager is not initialized.")
+ else:
+ return tokenizer_manager.tokenizer
+
+ # TODO (ByronHsu): encode
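Taken together, the `Engine` changes add streaming to `generate`, a new `async_generate`, and a `get_tokenizer` helper; with `stream=True`, the wrapper re-slices the cumulative `text` field of each SSE chunk into incremental deltas. A hedged usage sketch: the `Engine` constructor is not part of this diff and is assumed here to accept `ServerArgs`-style keyword arguments such as `model_path`, and the model path is a placeholder.

```python
import asyncio

from sglang.srt.server import Engine

if __name__ == "__main__":
    llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

    # Synchronous streaming: with stream=True, generate() returns a generator of
    # dicts whose "text" field holds only the newly produced delta.
    for chunk in llm.generate(
        "The capital of France is",
        {"temperature": 0.0, "max_new_tokens": 32},
        stream=True,
    ):
        print(chunk["text"], end="", flush=True)
    print()

    # Non-streaming async counterpart added in this release.
    async def main():
        out = await llm.async_generate("1 + 1 =", {"max_new_tokens": 8})
        print(out["text"])

    asyncio.run(main())

    llm.shutdown()
```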