sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.constrained import RegexGuide
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -13,22 +14,24 @@ if TYPE_CHECKING:
 
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
 
     # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
+    need_min_p_sampling: bool
 
     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
 
+    # FSM states
+    regex_fsms: List[RegexGuide] = None
+    regex_fsm_states: List[int] = None
+
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -37,24 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
-
-        with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+        with batch.input_ids.device:
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
 
-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
+        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -102,24 +111,24 @@ class SamplingBatchInfo:
             )
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
-    def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
+    def update_regex_vocab_mask(self):
+        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
 
         # Reset the vocab mask
         self.vocab_mask = None
 
         if has_regex:
             self.vocab_mask = torch.zeros(
-                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0
 
-    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
@@ -129,9 +138,9 @@ class SamplingBatchInfo:
             "min_ps",
             "logit_bias",
         ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+            value = getattr(self, item, None)
+            if value is not None:  # logit_bias can be None
+                setattr(self, item, value[new_indices])
 
     @staticmethod
     def merge_bias_tensor(
@@ -153,7 +162,7 @@ class SamplingBatchInfo:
 
         return None
 
-    def merge(self, other: "SamplingBatchInfo"):
+    def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
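The `update_regex_vocab_mask` refactor above builds a per-request boolean vocab mask from FSMs now stored on `SamplingBatchInfo` itself (rows default to all-allowed; for requests with an FSM, the row is filled and only the FSM-permitted tokens are cleared). A minimal sketch of how such a mask is typically consumed before sampling; the `logits` tensor, shapes, and the masked-softmax step are illustrative assumptions, not code from this package:

    import torch

    # Illustrative shapes: 2 requests, vocabulary of 8 tokens.
    batch_size, vocab_size = 2, 8
    logits = torch.randn(batch_size, vocab_size)

    # vocab_mask[i, t] == True means token t is disallowed for request i
    # (mirrors update_regex_vocab_mask: fill_(1), then zero out FSM-permitted tokens).
    vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
    vocab_mask[0] = True
    vocab_mask[0][[1, 4]] = False  # pretend the FSM allows tokens 1 and 4

    # Disallowed tokens get -inf so they can never be sampled.
    masked_logits = logits.masked_fill(vocab_mask, float("-inf"))
    probs = torch.softmax(masked_logits, dim=-1)
    print(probs[0])  # only tokens 1 and 4 receive nonzero probability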
sglang/srt/sampling/sampling_params.py CHANGED
@@ -26,7 +26,7 @@ class SamplingParams:
         max_new_tokens: int = 128,
         min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
-        stop_token_ids: Optional[List[int]] = [],
+        stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -49,6 +49,8 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        if stop_token_ids is None:
+            stop_token_ids = []
         self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
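The `stop_token_ids` change above replaces a mutable default argument (`[]`) with `None`, the usual Python idiom. Since the original constructor immediately copied the list into a set, this is mostly a hygiene fix, but the general pitfall is worth showing. The classes below are hypothetical and exist only to illustrate it:

    # Hypothetical example of the mutable-default pitfall addressed above.
    class Params:
        def __init__(self, stop_token_ids=[]):  # one shared list for every call
            self.stop_token_ids = stop_token_ids

    a = Params()
    a.stop_token_ids.append(7)
    b = Params()
    print(b.stop_token_ids)  # [7] -- the default list leaked from the first instance

    class SafeParams:
        def __init__(self, stop_token_ids=None):  # None sentinel, fresh container per call
            self.stop_token_ids = set(stop_token_ids or [])

    print(SafeParams().stop_token_ids)  # set()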
sglang/srt/server.py CHANGED
@@ -19,11 +19,13 @@ SRT = SGLang Runtime.
 """
 
 import asyncio
+import atexit
 import dataclasses
 import json
 import logging
 import multiprocessing as mp
 import os
+import random
 import threading
 import time
 from http import HTTPStatus
@@ -41,21 +43,15 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller_multi import (
-    start_controller_process as start_controller_process_multi,
-)
-from sglang.srt.managers.controller_single import launch_tp_servers
-from sglang.srt.managers.controller_single import (
-    start_controller_process as start_controller_process_single,
-)
-from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
+    RewardReqInput,
     UpdateWeightReqInput,
 )
+from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -74,15 +70,12 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
-    allocate_init_ports,
     assert_pkg_version,
     configure_logger,
-    enable_show_time_cost,
-    is_hip,
+    is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    prepare_model,
-    prepare_tokenizer,
+    prepare_model_and_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -127,6 +120,7 @@ async def health_generate(request: Request) -> Response:
 
 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -136,11 +130,13 @@ async def get_model_info():
 
 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)
 
 
 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -151,7 +147,7 @@ async def flush_cache():
 
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
-
+    """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
@@ -166,6 +162,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     )
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -213,6 +210,21 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)
 
 
+async def judge_request(obj: RewardReqInput, request: Request):
+    """Handle a reward model request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/judge")(judge_request)
+app.put("/judge")(judge_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
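The new `/judge` route above forwards a reward-model request through the tokenizer manager. A hedged client-side sketch, assuming a server already running on localhost port 30000; the `"conv"` key mirrors the payload built by `Runtime.encode` further down, but the exact message schema (single conversation vs. a list of conversations) is an assumption:

    import requests

    # Hypothetical reward-model query against a locally running server.
    conv = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
    resp = requests.post(
        "http://127.0.0.1:30000/judge",
        json={"conv": conv},  # same shape as Runtime.encode's reward branch
    )
    print(resp.json())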
@@ -280,102 +292,95 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
-def launch_server(
+def launch_engine(
     server_args: ServerArgs,
-    pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    """Launch an HTTP server."""
+    """
+    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
+    """
+
     global tokenizer_manager
 
+    # Configure global environment
     configure_logger(server_args)
-
     server_args.check_server_args()
     _set_envs_and_config(server_args)
 
     # Allocate ports for inter-process communications
-    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
-        server_args.additional_ports,
-        server_args.dp_size,
-    )
-    ports = server_args.additional_ports
-    port_args = PortArgs(
-        tokenizer_port=ports[0],
-        controller_port=ports[1],
-        detokenizer_port=ports[2],
-        nccl_ports=ports[3:],
-    )
+    port_args = PortArgs.init_new(server_args)
     logger.info(f"{server_args=}")
 
-    # Use model from www.modelscope.cn, first download the model.
-    server_args.model_path = prepare_model(server_args.model_path)
-    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
-
-    # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1 and server_args.node_rank != 0:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-        tp_rank_range = list(
-            range(
-                server_args.node_rank * tp_size_local,
-                (server_args.node_rank + 1) * tp_size_local,
-            )
-        )
-        procs = launch_tp_servers(
-            gpu_ids,
-            tp_rank_range,
-            server_args,
-            ports[3],
-        )
-
-        try:
-            for p in procs:
-                p.join()
-        finally:
-            kill_child_process(os.getpid(), including_parent=False)
-        return
-
-    # Launch processes
-    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
+    # If using model from www.modelscope.cn, first download the model.
+    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+        server_args.model_path, server_args.tokenizer_path
+    )
 
-    if server_args.dp_size == 1:
-        start_controller_process = start_controller_process_single
-    else:
-        start_controller_process = start_controller_process_multi
-    proc_controller = mp.Process(
-        target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer),
+    # Launch tensor parallel scheduler processes
+    scheduler_procs = []
+    scheduler_pipe_readers = []
+    tp_size_per_node = server_args.tp_size // server_args.nnodes
+    tp_rank_range = range(
+        tp_size_per_node * server_args.node_rank,
+        tp_size_per_node * (server_args.node_rank + 1),
     )
-    proc_controller.start()
+    for tp_rank in tp_rank_range:
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = tp_rank % tp_size_per_node
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, writer),
+        )
+        proc.start()
+        scheduler_procs.append(proc)
+        scheduler_pipe_readers.append(reader)
+
+    if server_args.node_rank >= 1:
+        # For other nodes, they do not need to run tokenizer or detokenizer,
+        # so they can just wait here.
+        while True:
+            pass
 
-    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
-    proc_detoken = mp.Process(
-        target=start_detokenizer_process,
+    # Launch detokenizer process
+    detoken_proc = mp.Process(
+        target=run_detokenizer_process,
         args=(
             server_args,
             port_args,
-            pipe_detoken_writer,
         ),
     )
-    proc_detoken.start()
+    detoken_proc.start()
 
+    # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
-    # Wait for the model to finish loading
-    controller_init_state = pipe_controller_reader.recv()
-    detoken_init_state = pipe_detoken_reader.recv()
-
-    if controller_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_controller.kill()
-        proc_detoken.kill()
-        raise RuntimeError(
-            "Initialization failed. "
-            f"controller_init_state: {controller_init_state}, "
-            f"detoken_init_state: {detoken_init_state}"
-        )
-    assert proc_controller.is_alive() and proc_detoken.is_alive()
+    # Wait for model to finish loading
+    for i in range(len(scheduler_pipe_readers)):
+        scheduler_pipe_readers[i].recv()
+
+
+def launch_server(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[mp.connection.Connection] = None,
+):
+    """
+    Launch SRT (SGLang Runtime) Server
+
+    The SRT server consists of an HTTP server and the SRT engine.
+
+    1. HTTP server: A FastAPI server that routes requests to the engine.
+    2. SRT engine:
+        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
+        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+    Note:
+    1. The HTTP server and Tokenizer Manager both run in the main process.
+    2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
+    """
+
+    launch_engine(server_args=server_args)
 
     # Add api key authorization
     if server_args.api_key:
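For reference, the refactored entry point above can be driven programmatically roughly as sketched below. The model path is a placeholder and only the keyword arguments visible in this diff are used; treat it as a hedged usage note rather than documented API:

    # Hypothetical launch sketch; model_path is a placeholder.
    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":
        server_args = ServerArgs(
            model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
            port=30000,
        )
        # launch_engine starts the scheduler and detokenizer subprocesses,
        # then uvicorn serves the FastAPI app in this process.
        launch_server(server_args)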
@@ -388,7 +393,7 @@ def launch_server(
     t.start()
 
     try:
-        # Listen for requests
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -412,14 +417,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()
 
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
     # Fix triton bugs
     if server_args.tp_size * server_args.dp_size > 1:
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
@@ -435,9 +432,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         "at https://docs.flashinfer.ai/installation.html.",
     )
 
-    if is_hip():
-        # to figure out a better method of not using fork later
-        mp.set_start_method("spawn", force=True)
+    mp.set_start_method("spawn", force=True)
 
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
@@ -467,7 +462,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         return
 
     model_info = res.json()
-
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -501,7 +495,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
-        pipe_finish_writer.send("init ok")
+        pipe_finish_writer.send("ready")
 
 
 class Runtime:
@@ -520,18 +514,20 @@ class Runtime:
         """See the arguments in server_args.py::ServerArgs"""
         self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
 
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
         # Pre-allocate ports
-        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port,
-            self.server_args.additional_ports,
-            self.server_args.dp_size,
-        )
+        for port in range(10000, 40000):
+            if is_port_available(port):
+                break
+            port += 1
+        self.server_args.port = port
 
         self.url = self.server_args.url()
-        self.generate_url = (
-            f"http://{self.server_args.host}:{self.server_args.port}/generate"
-        )
+        self.generate_url = self.url + "/generate"
 
+        # NOTE: We store pid instead of proc to fix some issues during __delete__
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
@@ -548,7 +544,7 @@ class Runtime:
         except EOFError:
             init_state = ""
 
-        if init_state != "init ok":
+        if init_state != "ready":
             self.shutdown()
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
@@ -599,7 +595,7 @@ class Runtime:
             if chunk == "data: [DONE]\n\n":
                 break
             data = json.loads(chunk[5:].strip("\n"))
-            if hasattr(data, "text"):
+            if "text" in data:
                cur = data["text"][pos:]
                if cur:
                    yield cur
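The streaming fix above switches from `hasattr(data, "text")` to `"text" in data`: `json.loads` returns a plain `dict`, and `hasattr` checks attributes rather than keys, so the old test was always false. A quick standalone illustration:

    import json

    data = json.loads('{"text": "hello", "meta": {"finish_reason": null}}')

    print(hasattr(data, "text"))  # False: dicts expose keys, not attributes
    print("text" in data)         # True: membership tests keys
    print(data.get("text"))       # "hello"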
@@ -635,16 +631,71 @@ class Runtime:
 
     def encode(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-        json_data = {
-            "text": prompt,
-        }
-        response = requests.post(
-            self.url + "/encode",
-            json=json_data,
-        )
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            # embedding
+            json_data = {
+                "text": prompt,
+            }
+            response = requests.post(
+                self.url + "/encode",
+                json=json_data,
+            )
+        else:
+            # reward
+            json_data = {
+                "conv": prompt,
+            }
+            response = requests.post(
+                self.url + "/judge",
+                json=json_data,
+            )
         return json.dumps(response.json())
 
     def __del__(self):
         self.shutdown()
+
+
+class Engine:
+    """
+    SRT Engine without an HTTP server layer.
+
+    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
+    launching the HTTP server adds unnecessary complexity or overhead,
+    """
+
+    def __init__(self, *args, **kwargs):
+
+        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
+        atexit.register(self.shutdown)
+
+        server_args = ServerArgs(*args, **kwargs)
+        launch_engine(server_args=server_args)
+
+    def generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+        )
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(generate_request(obj, None))
+
+    def shutdown(self):
+        kill_child_process(os.getpid(), including_parent=False)
+
+    # TODO (ByronHsu): encode and async generate
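A hedged usage sketch for the new `Engine` class added above; the model path is a placeholder and the sampling parameters are limited to fields visible in `SamplingParams` in this diff:

    # Hypothetical usage of the HTTP-less Engine; model_path is a placeholder.
    from sglang.srt.server import Engine

    engine = Engine(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
    result = engine.generate(
        "The capital of France is",
        sampling_params={"temperature": 0.0, "max_new_tokens": 8},
    )
    print(result)

    # Engine registers shutdown via atexit, so explicit cleanup is optional,
    # but it can also be shut down manually:
    engine.shutdown()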