sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ # ==============================================================================
12
+
13
+ import copy
14
+ import uuid
15
+ from dataclasses import dataclass
16
+ from typing import Optional
17
+
18
+ from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
19
+ from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
20
+
21
+
22
+ class Session:
23
+ def __init__(self, capacity_of_str_len: int, session_id: str = None):
24
+ self.session_id = session_id if session_id is not None else uuid.uuid4().hex
25
+ self.capacity_of_str_len = capacity_of_str_len
26
+ self.reqs: List[Req] = []
27
+
28
+ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
29
+ if req.session_rid is not None:
30
+ while len(self.reqs) > 0:
31
+ if self.reqs[-1].rid == req.session_rid:
32
+ break
33
+ self.reqs = self.reqs[:-1]
34
+ else:
35
+ self.reqs = []
36
+ if len(self.reqs) > 0:
37
+ input_ids = (
38
+ self.reqs[-1].origin_input_ids
39
+ + self.reqs[-1].output_ids[
40
+ : self.reqs[-1].sampling_params.max_new_tokens
41
+ ]
42
+ + req.input_ids
43
+ )
44
+ else:
45
+ input_ids = req.input_ids
46
+ new_req = Req(
47
+ req.rid,
48
+ None,
49
+ input_ids,
50
+ req.sampling_params,
51
+ lora_path=req.lora_path,
52
+ session_id=self.session_id,
53
+ )
54
+ new_req.tokenizer = tokenizer
55
+ if req.session_rid is not None and len(self.reqs) == 0:
56
+ new_req.finished_reason = FINISH_ABORT(
57
+ f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
58
+ )
59
+ else:
60
+ self.reqs.append(new_req)
61
+ return new_req
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """TokenizerManager is a process that tokenizes the text."""
17
15
 
18
16
  import asyncio
@@ -23,6 +21,7 @@ import os
23
21
  import signal
24
22
  import sys
25
23
  import time
24
+ import uuid
26
25
  from typing import Dict, List, Optional, Tuple, Union
27
26
 
28
27
  import fastapi
@@ -42,11 +41,14 @@ from sglang.srt.managers.io_struct import (
42
41
  BatchEmbeddingOut,
43
42
  BatchStrOut,
44
43
  BatchTokenIDOut,
44
+ CloseSessionReqInput,
45
45
  EmbeddingReqInput,
46
46
  FlushCacheReq,
47
47
  GenerateReqInput,
48
48
  GetMemPoolSizeReq,
49
49
  GetMemPoolSizeReqOutput,
50
+ OpenSessionReqInput,
51
+ OpenSessionReqOutput,
50
52
  ProfileReq,
51
53
  TokenizedEmbeddingReqInput,
52
54
  TokenizedGenerateReqInput,
@@ -146,6 +148,9 @@ class TokenizerManager:
146
148
  self.model_update_lock = asyncio.Lock()
147
149
  self.model_update_result = None
148
150
 
151
+ # For session info
152
+ self.session_futures = {} # session_id -> asyncio event
153
+
149
154
  # Others
150
155
  self.gracefully_exit = False
151
156
 
@@ -196,8 +201,18 @@ class TokenizerManager:
196
201
  ):
197
202
  """Tokenize one request."""
198
203
  # Tokenize
204
+ input_embeds = None
199
205
  input_text = obj.text
200
- if obj.input_ids is None:
206
+ if obj.input_embeds is not None:
207
+ if not self.server_args.disable_radix_cache:
208
+ raise ValueError(
209
+ "input_embeds is provided while disable_radix_cache is False. "
210
+ "Please add `--disable-radix-cach` when you launch the server "
211
+ "if you want to use input_embeds as inputs."
212
+ )
213
+ input_embeds = obj.input_embeds
214
+ input_ids = obj.input_ids
215
+ elif obj.input_ids is None:
201
216
  input_ids = self.tokenizer.encode(input_text)
202
217
  else:
203
218
  input_ids = obj.input_ids
@@ -211,8 +226,10 @@ class TokenizerManager:
211
226
  return_logprob = obj.return_logprob
212
227
  logprob_start_len = obj.logprob_start_len
213
228
  top_logprobs_num = obj.top_logprobs_num
229
+ session_id = obj.session[0] if obj.session else None
230
+ session_rid = obj.session[1] if obj.session else None
214
231
 
215
- if len(input_ids) >= self.context_len:
232
+ if obj.input_ids is not None and len(input_ids) >= self.context_len:
216
233
  raise ValueError(
217
234
  f"The input ({len(input_ids)} tokens) is longer than the "
218
235
  f"model's context length ({self.context_len} tokens)."
@@ -235,7 +252,10 @@ class TokenizerManager:
235
252
  logprob_start_len,
236
253
  top_logprobs_num,
237
254
  obj.stream,
238
- obj.lora_path,
255
+ lora_path=obj.lora_path,
256
+ input_embeds=input_embeds,
257
+ session_id=session_id,
258
+ session_rid=session_rid,
239
259
  )
240
260
  elif isinstance(obj, EmbeddingReqInput):
241
261
  tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +471,26 @@ class TokenizerManager:
451
471
  else:
452
472
  return False, "Another update is in progress. Please try again later."
453
473
 
474
+ async def open_session(
475
+ self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
476
+ ):
477
+ if self.to_create_loop:
478
+ self.create_handle_loop()
479
+
480
+ session_id = uuid.uuid4().hex
481
+ obj.session_id = session_id
482
+ self.send_to_scheduler.send_pyobj(obj)
483
+ self.session_futures[session_id] = asyncio.Future()
484
+ session_id = await self.session_futures[session_id]
485
+ del self.session_futures[session_id]
486
+ return session_id
487
+
488
+ async def close_session(
489
+ self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
490
+ ):
491
+ assert not self.to_create_loop, "close session should not be the first request"
492
+ await self.send_to_scheduler.send_pyobj(obj)
493
+
454
494
  def create_abort_task(self, obj: GenerateReqInput):
455
495
  # Abort the request if the client is disconnected.
456
496
  async def abort_request():
@@ -521,6 +561,11 @@ class TokenizerManager:
521
561
  if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
522
562
  self.mem_pool_size.set_result(self.mem_pool_size_tmp)
523
563
  continue
564
+ elif isinstance(recv_obj, OpenSessionReqOutput):
565
+ self.session_futures[recv_obj.session_id].set_result(
566
+ recv_obj.session_id
567
+ )
568
+ continue
524
569
 
525
570
  assert isinstance(
526
571
  recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -1,21 +1,20 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """A tensor parallel worker."""
17
15
 
18
16
  import logging
17
+ import threading
19
18
  from typing import Optional
20
19
 
21
20
  from sglang.srt.configs.model_config import ModelConfig
@@ -134,9 +133,19 @@ class TpModelWorker:
134
133
  self.model_runner.token_to_kv_pool,
135
134
  )
136
135
 
137
- def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
136
+ def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
137
+ forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
138
+ self.model_runner.forward(forward_batch)
139
+
140
+ def forward_batch_generation(
141
+ self,
142
+ model_worker_batch: ModelWorkerBatch,
143
+ launch_done: Optional[threading.Event] = None,
144
+ ):
138
145
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
139
146
  logits_output = self.model_runner.forward(forward_batch)
147
+ if launch_done:
148
+ launch_done.set()
140
149
  next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
141
150
  return logits_output, next_token_ids
142
151
 
@@ -1,23 +1,21 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """A tensor parallel worker."""
17
15
 
16
+ import dataclasses
18
17
  import logging
19
18
  import threading
20
- import time
21
19
  from queue import Queue
22
20
  from typing import Optional
23
21
 
@@ -26,7 +24,6 @@ import torch
26
24
  from sglang.srt.managers.io_struct import UpdateWeightReqInput
27
25
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch
28
26
  from sglang.srt.managers.tp_worker import TpModelWorker
29
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
30
27
  from sglang.srt.server_args import ServerArgs
31
28
 
32
29
  logger = logging.getLogger(__name__)
@@ -56,6 +53,7 @@ class TpModelWorkerClient:
56
53
  self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
57
54
  self.max_running_requests = self.worker.max_running_requests
58
55
  self.device = self.worker.device
56
+ self.gpu_id = gpu_id
59
57
 
60
58
  # Init future mappings
61
59
  self.future_token_ids_ct = 0
@@ -73,12 +71,6 @@ class TpModelWorkerClient:
73
71
  )
74
72
  self.forward_thread.start()
75
73
 
76
- self.copy_queue = Queue()
77
- self.copy_thread = threading.Thread(
78
- target=self.copy_thread_func,
79
- )
80
- self.copy_thread.start()
81
-
82
74
  def get_worker_info(self):
83
75
  return self.worker.get_worker_info()
84
76
 
@@ -98,15 +90,25 @@ class TpModelWorkerClient:
98
90
  with torch.cuda.stream(self.forward_stream):
99
91
  self.forward_thread_func_()
100
92
 
101
- @torch.inference_mode()
93
+ @torch.no_grad()
102
94
  def forward_thread_func_(self):
95
+ batch_pt = 0
96
+ batch_lists = [None] * 2
97
+
103
98
  while True:
104
- self.has_inflight_batch = False
105
99
  model_worker_batch, future_token_ids_ct = self.input_queue.get()
106
100
  if not model_worker_batch:
107
101
  break
108
- self.has_inflight_batch = True
109
- self.launch_event = threading.Event()
102
+
103
+ # Keep a reference of model_worker_batch by storing it into a list.
104
+ # Otherwise, the tensor members of model_worker_batch will be released
105
+ # by pytorch and cause CUDA illegal memory access errors.
106
+ batch_lists[batch_pt % 2] = model_worker_batch
107
+ batch_pt += 1
108
+
109
+ # Create event
110
+ self.launch_done = threading.Event()
111
+ copy_done = torch.cuda.Event()
110
112
 
111
113
  # Resolve future tokens in the input
112
114
  input_ids = model_worker_batch.input_ids
@@ -114,7 +116,7 @@ class TpModelWorkerClient:
114
116
 
115
117
  # Run forward
116
118
  logits_output, next_token_ids = self.worker.forward_batch_generation(
117
- model_worker_batch
119
+ model_worker_batch, self.launch_done
118
120
  )
119
121
 
120
122
  # Update the future token ids map
@@ -139,44 +141,45 @@ class TpModelWorkerClient:
139
141
  )
140
142
  )
141
143
  next_token_ids = next_token_ids.to("cpu", non_blocking=True)
142
- copy_event = torch.cuda.Event(blocking=True)
143
- copy_event.record()
144
+ copy_done.record()
144
145
 
145
- self.launch_event.set()
146
- self.copy_queue.put((copy_event, logits_output, next_token_ids))
146
+ self.output_queue.put((copy_done, logits_output, next_token_ids))
147
147
 
148
- def copy_thread_func(self):
149
- while True:
150
- copy_event, logits_output, next_token_ids = self.copy_queue.get()
151
- if not copy_event:
152
- break
153
- while not copy_event.query():
154
- time.sleep(1e-5)
148
+ def resolve_batch_result(self, bid: int):
149
+ copy_done, logits_output, next_token_ids = self.output_queue.get()
150
+ copy_done.synchronize()
151
+ self.launch_done.wait()
155
152
 
156
- if logits_output.next_token_logprobs is not None:
157
- logits_output.next_token_logprobs = (
158
- logits_output.next_token_logprobs.tolist()
153
+ if logits_output.next_token_logprobs is not None:
154
+ logits_output.next_token_logprobs = (
155
+ logits_output.next_token_logprobs.tolist()
156
+ )
157
+ if logits_output.input_token_logprobs is not None:
158
+ logits_output.input_token_logprobs = (
159
+ logits_output.input_token_logprobs.tolist()
159
160
  )
160
- if logits_output.input_token_logprobs is not None:
161
- logits_output.input_token_logprobs = (
162
- logits_output.input_token_logprobs.tolist()
163
- )
164
- logits_output.normalized_prompt_logprobs = (
165
- logits_output.normalized_prompt_logprobs.tolist()
166
- )
167
-
168
- self.output_queue.put((logits_output, next_token_ids.tolist()))
169
-
170
- def resulve_batch_result(self, bid: int):
171
- logits_output, next_token_ids = self.output_queue.get()
172
- if self.has_inflight_batch:
173
- # Wait until the batch is launched
174
- self.launch_event.wait()
161
+ logits_output.normalized_prompt_logprobs = (
162
+ logits_output.normalized_prompt_logprobs.tolist()
163
+ )
164
+ next_token_ids = next_token_ids.tolist()
175
165
  return logits_output, next_token_ids
176
166
 
177
167
  def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
168
+ # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
169
+ sampling_info = model_worker_batch.sampling_info
170
+ sampling_info.update_penalties()
171
+ model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
172
+ sampling_info,
173
+ sampling_info_done=threading.Event(),
174
+ scaling_penalties=sampling_info.scaling_penalties,
175
+ linear_penalties=sampling_info.linear_penalties,
176
+ )
177
+
178
+ # A cuda stream sync here to avoid the cuda illegal memory access error.
179
+ torch.cuda.current_stream().synchronize()
180
+
178
181
  # Push a new batch to the queue
179
- self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
182
+ self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
180
183
 
181
184
  # Allocate output future objects
182
185
  bs = len(model_worker_batch.seq_lens)
@@ -192,16 +195,8 @@ class TpModelWorkerClient:
192
195
  ) % self.future_token_ids_limit
193
196
  return None, future_next_token_ids
194
197
 
195
- def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
196
- forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
197
- logits_output = self.model_runner.forward(forward_batch)
198
- embeddings = logits_output.embeddings
199
- return embeddings
200
-
201
198
  def update_weights(self, recv_req: UpdateWeightReqInput):
202
- success, message = self.model_runner.update_weights(
203
- recv_req.model_path, recv_req.load_format
204
- )
199
+ success, message = self.worker.update_weights(recv_req)
205
200
  return success, message
206
201
 
207
202
  def __delete__(self):
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """Utilities for Prometheus Metrics Collection."""
17
15
 
18
16
  from dataclasses import dataclass
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """
17
15
  Records the latency of some functions
18
16
  """
sglang/srt/mm_utils.py CHANGED
@@ -1,17 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
15
14
 
16
15
  # Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
17
16
  """