sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/parallel_state.py
CHANGED

@@ -650,17 +650,19 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
 
+        # All-gather.
+        if input_.is_cpu and is_shm_available(
+            input_.dtype, self.world_size, self.local_size
+        ):
+            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+
         if input_.is_cpu:
-            if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-            else:
-                torch.distributed.all_gather_into_tensor(
-                    output_tensor, input_, group=self.device_group
-                )
-                return output_tensor
+            torch.distributed.all_gather_into_tensor(
+                output_tensor, input_, group=self.device_group
+            )
+        else:
+            self.all_gather_into_tensor(output_tensor, input_)
 
-        # All-gather.
-        self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
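The hunk above hoists the shared-memory fast path out of the CPU branch and makes the CPU and device all-gather paths share the common reshape/movedim tail. Below is a minimal, self-contained sketch of the reordered control flow; the sglang-internal pieces (`is_shm_available`, `torch.ops.sgl_kernel.shm_allgather`, the device-group all-gather) are stubbed as injected callables, so only the dispatch order is illustrated:

```python
import torch


def all_gather_dispatch(
    input_: torch.Tensor,
    world_size: int,
    dim: int,
    shm_available: bool,   # stands in for is_shm_available(...)
    shm_allgather,         # stands in for torch.ops.sgl_kernel.shm_allgather
    cpu_allgather,         # stands in for torch.distributed.all_gather_into_tensor
    device_allgather,      # stands in for self.all_gather_into_tensor
) -> torch.Tensor:
    # Fast path (new in 0.5.0rc0): the shared-memory all-gather returns
    # before the output buffer is ever written.
    if input_.is_cpu and shm_available:
        return shm_allgather(input_, dim)
    output = torch.empty(
        (world_size,) + tuple(input_.size()), dtype=input_.dtype, device=input_.device
    )
    # Both remaining branches now fall through to the shared movedim tail.
    if input_.is_cpu:
        cpu_allgather(output, input_)
    else:
        device_allgather(output, input_)
    return output.movedim(0, dim)
```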
sglang/srt/entrypoints/context.py
ADDED

@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# Copied from vLLM
import json
import logging
from abc import ABC, abstractmethod
from typing import Union

logger = logging.getLogger(__name__)

try:
    from mcp import ClientSession
except ImportError:
    logger.warning("Ignoring mcp import error")

from openai_harmony import Author, Message, Role, StreamState, TextContent

from sglang.srt.entrypoints.harmony_utils import (
    get_encoding,
    get_streamable_parser_for_assistant,
    render_for_completion,
)
from sglang.srt.entrypoints.tool import Tool


class ConversationContext(ABC):

    @abstractmethod
    def append_output(self, output) -> None:
        pass

    @abstractmethod
    async def call_tool(self) -> list[Message]:
        pass

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        pass

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        pass


class SimpleContext(ConversationContext):

    def __init__(self):
        self.last_output = None

    def append_output(self, output) -> None:
        self.last_output = output

    def need_builtin_tool_call(self) -> bool:
        return False

    async def call_tool(self) -> list[Message]:
        raise NotImplementedError("Should not be called.")

    def render_for_completion(self) -> list[int]:
        raise NotImplementedError("Should not be called.")


class HarmonyContext(ConversationContext):

    def __init__(
        self,
        messages: list,
        tool_sessions: dict[str, Union["ClientSession", Tool]],
    ):
        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
        # when demo.
        self._messages = messages
        self.tool_sessions = tool_sessions

        self.parser = get_streamable_parser_for_assistant()
        self.num_init_messages = len(messages)
        # TODO
        self.num_prompt_tokens = 0
        self.num_cached_tokens = 0
        self.num_output_tokens = 0
        self.num_reasoning_tokens = 0

    def append_output(self, output) -> None:
        if isinstance(output, dict) and "output_ids" in output:
            output_token_ids = output["output_ids"]

            # TODO: REMOVE here:
            # Very hacky, find the first occurrence of token 200006 and cut from there
            try:
                start_index = output_token_ids.index(200006)
                output_token_ids = output_token_ids[start_index:]
            except ValueError:
                pass

            for token_id in output_token_ids:
                self.parser.process(token_id)
            output_msgs = self.parser.messages

            meta_info = output["meta_info"]

            if isinstance(meta_info, dict):
                if "prompt_token_ids" in meta_info:
                    self.num_prompt_tokens = meta_info["prompt_tokens"]
                if "cached_tokens" in meta_info:
                    self.num_cached_tokens = meta_info["cached_tokens"]
                if "completion_tokens" in meta_info:
                    self.num_output_tokens += meta_info["completion_tokens"]

        else:
            output_msgs = output

        self._messages.extend(output_msgs)

    @property
    def messages(self) -> list:
        return self._messages

    def need_builtin_tool_call(self) -> bool:
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        return recipient is not None and (
            recipient.startswith("browser.") or recipient.startswith("python")
        )

    async def call_tool(self) -> list[Message]:
        if not self.messages:
            return []
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        if recipient is not None:
            if recipient.startswith("browser."):
                return await self.call_search_tool(
                    self.tool_sessions["browser"], last_msg
                )
            elif recipient.startswith("python"):
                return await self.call_python_tool(
                    self.tool_sessions["python"], last_msg
                )
        raise ValueError("No tool call found")

    def render_for_completion(self) -> list[int]:
        return render_for_completion(self.messages)

    async def call_search_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        tool_name = last_msg.recipient.split(".")[1]
        args = json.loads(last_msg.content[0].text)
        result = await tool_session.call_tool(tool_name, args)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name=last_msg.recipient)
        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        param = {
            "code": last_msg.content[0].text,
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text

        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name="python")

        return [
            Message(
                author=author,
                content=[content],
                channel=last_msg.channel,
                recipient=Role.ASSISTANT,
            )
        ]


class StreamingHarmonyContext(HarmonyContext):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_output = None

        self.parser = get_streamable_parser_for_assistant()
        self.encoding = get_encoding()
        self.last_tok = None

    @property
    def messages(self) -> list:
        return self.parser.messages

    def append_output(self, output) -> None:
        if isinstance(output, dict) and "output_ids" in output:
            # RequestOutput from SGLang with outputs
            output_token_ids = output["output_ids"]

            # TODO: REMOVE here:
            # Very hacky, find the first occurrence of token 200006 and cut from there
            # Find the first occurrence of token 200006 and cut from there
            try:
                start_index = output_token_ids.index(200006)
                output_token_ids = output_token_ids[start_index:]
            except ValueError:
                pass

            for token_id in output_token_ids:
                self.parser.process(token_id)

        else:
            # Handle the case of tool output in direct message format
            assert len(output) == 1, "Tool output should be a single message"
            msg = output[0]
            # Sometimes the recipient is not set for tool messages,
            # so we set it to "assistant"
            if msg.author.role == Role.TOOL and msg.recipient is None:
                msg.recipient = "assistant"
            toks = self.encoding.render(msg)
            for tok in toks:
                self.parser.process(tok)
            self.last_tok = toks[-1]

    def is_expecting_start(self) -> bool:
        return self.parser.state == StreamState.EXPECT_START

    def is_assistant_action_turn(self) -> bool:
        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()

    def render_for_completion(self) -> list[int]:
        # now this list of tokens as next turn's starting tokens
        # `<|start|>assistant``,
        # we need to process them in parser.
        rendered_tokens = super().render_for_completion()

        last_n = -1
        to_process = []
        while rendered_tokens[last_n] != self.last_tok:
            to_process.append(rendered_tokens[last_n])
            last_n -= 1
        for tok in reversed(to_process):
            self.parser.process(tok)

        return rendered_tokens
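A hypothetical driver loop sketching how these contexts are meant to be used: generate, check for a built-in tool call, execute it, and re-render the conversation for the next turn. The `generate` callable is an assumption standing in for the actual SGLang engine call, as is the exact shape of its output dict:

```python
from sglang.srt.entrypoints.context import HarmonyContext


async def agent_loop(generate, messages, tool_sessions, max_turns: int = 8):
    ctx = HarmonyContext(messages, tool_sessions)
    for _ in range(max_turns):
        # `generate` is assumed to take prompt token ids and return a dict
        # shaped like {"output_ids": [...], "meta_info": {...}}.
        ctx.append_output(await generate(ctx.render_for_completion()))
        if not ctx.need_builtin_tool_call():
            break  # the assistant produced a final answer
        # Run the browser/python builtin and feed the tool-authored
        # messages back into the conversation.
        ctx.append_output(await ctx.call_tool())
    return ctx.messages
```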
sglang/srt/entrypoints/engine.py
CHANGED

@@ -492,12 +492,13 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
         """Load a new LoRA adapter without re-launching the engine."""
 
         obj = LoadLoRAAdapterReqInput(
             lora_name=lora_name,
             lora_path=lora_path,
+            pinned=pinned,
         )
 
         loop = asyncio.get_event_loop()
@@ -641,7 +642,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.10",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -649,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2
+            "0.3.2",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
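A hypothetical use of the widened `load_lora_adapter` signature. The model and adapter paths are placeholders, and the eviction semantics of `pinned` (keeping the adapter resident in the LoRA memory pool, suggested by the `lora_manager`/`mem_pool` entries in the file list above) are an assumption:

```python
import sglang as sgl

# Placeholder model/adapter paths; enable_lora mirrors the server flag
# for dynamic adapter loading.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    enable_lora=True,
    max_loras_per_batch=4,
)
engine.load_lora_adapter(
    lora_name="my-adapter",
    lora_path="/path/to/adapter",
    pinned=True,  # new in 0.5.0rc0, defaults to False
)
engine.shutdown()
```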
sglang/srt/entrypoints/harmony_utils.py
ADDED

@@ -0,0 +1,370 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
import json
from collections.abc import Iterable
from typing import Literal, Optional, Union

from openai.types.responses import (
    ResponseOutputItem,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
)
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_web_search import (
    ActionFind,
    ActionOpenPage,
    ActionSearch,
    ResponseFunctionWebSearch,
)
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
    Author,
    Conversation,
    DeveloperContent,
    HarmonyEncodingName,
    Message,
    ReasoningEffort,
    Role,
    StreamableParser,
    SystemContent,
    TextContent,
    ToolDescription,
    load_harmony_encoding,
)

from sglang.srt.entrypoints.openai.protocol import ResponseInputOutputItem
from sglang.srt.utils import random_uuid

REASONING_EFFORT = {
    "high": ReasoningEffort.HIGH,
    "medium": ReasoningEffort.MEDIUM,
    "low": ReasoningEffort.LOW,
}

_harmony_encoding = None


def get_encoding():
    global _harmony_encoding
    if _harmony_encoding is None:
        _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def get_system_message(
    model_identity: Optional[str] = None,
    reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
    start_date: Optional[str] = None,
    browser_description: Optional[str] = None,
    python_description: Optional[str] = None,
) -> Message:
    sys_msg_content = SystemContent.new()
    if model_identity is not None:
        sys_msg_content = sys_msg_content.with_model_identity(model_identity)
    if reasoning_effort is not None:
        sys_msg_content = sys_msg_content.with_reasoning_effort(
            REASONING_EFFORT[reasoning_effort]
        )
    if start_date is None:
        start_date = datetime.datetime.now().strftime("%Y-%m-%d")
    sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
    if browser_description is not None:
        sys_msg_content = sys_msg_content.with_tools(browser_description)
    if python_description is not None:
        sys_msg_content = sys_msg_content.with_tools(python_description)
    sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
    return sys_msg


def get_developer_message(
    instructions: Optional[str] = None, tools: Optional[list[Tool]] = None
) -> Message:
    dev_msg_content = DeveloperContent.new()
    if instructions is not None:
        dev_msg_content = dev_msg_content.with_instructions(instructions)
    if tools is not None:
        function_tools = []
        for tool in tools:
            if tool.type in ("web_search_preview", "code_interpreter"):
                # These are built-in tools that are added to the system message.
                pass
            elif tool.type == "function":
                function_tools.append(tool)
            else:
                raise ValueError(f"tool type {tool.type} not supported")
        if function_tools:
            function_tool_descriptions = [
                ToolDescription.new(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters,
                )
                for tool in function_tools
            ]
            dev_msg_content = dev_msg_content.with_function_tools(
                function_tool_descriptions
            )
    dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
    return dev_msg


def get_user_message(content: str) -> Message:
    return Message.from_role_and_content(Role.USER, content)


def parse_response_input(
    response_msg: ResponseInputOutputItem,
    prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]],
) -> Message:
    if not isinstance(response_msg, dict):
        response_msg = response_msg.model_dump()
    if "type" not in response_msg or response_msg["type"] == "message":
        role = response_msg["role"]
        content = response_msg["content"]
        if role == "system":
            # User is trying to set a system message. Change it to:
            # <|start|>developer<|message|># Instructions
            # {instructions}<|end|>
            role = "developer"
            text_prefix = "Instructions:\n"
        else:
            text_prefix = ""
        if isinstance(content, str):
            msg = Message.from_role_and_content(role, text_prefix + content)
        else:
            contents = [TextContent(text=text_prefix + c["text"]) for c in content]
            msg = Message.from_role_and_contents(role, contents)
    elif response_msg["type"] == "function_call_output":
        call_id = response_msg["call_id"]
        call_response: Optional[ResponseFunctionToolCall] = None
        for prev_response in reversed(prev_responses):
            if (
                isinstance(prev_response, ResponseFunctionToolCall)
                and prev_response.call_id == call_id
            ):
                call_response = prev_response
                break
        if call_response is None:
            raise ValueError(f"No call message found for {call_id}")
        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{call_response.name}"),
            response_msg["output"],
        )
    elif response_msg["type"] == "reasoning":
        content = response_msg["content"]
        assert len(content) == 1
        msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
    elif response_msg["type"] == "function_call":
        msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(f"functions.{response_msg['name']}")
        msg = msg.with_content_type("json")
    else:
        raise ValueError(f"Unknown input type: {response_msg['type']}")
    return msg


def parse_response_output(output: ResponseOutputItem) -> Message:
    if isinstance(output, ResponseOutputMessage):
        role = output.role
        contents = [TextContent(text=c.text) for c in output.content]
        msg = Message.from_role_and_contents(role, contents)
        return msg
    elif isinstance(output, ResponseFunctionToolCall):
        msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments)
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(output.name)
        msg = msg.with_content_type("json")
        return msg
    else:
        raise ValueError(f"Unknown output type: {type(output)}")


def parse_chat_input(chat_msg) -> Message:
    role = chat_msg.role
    content = chat_msg.content
    if isinstance(content, str):
        contents = [TextContent(text=content)]
    else:
        # TODO: Support refusal.
        contents = [TextContent(text=c.text) for c in content]
    msg = Message.from_role_and_contents(role, contents)
    return msg


def render_for_completion(messages: list[Message]) -> list[int]:
    conversation = Conversation.from_messages(messages)
    token_ids = get_encoding().render_conversation_for_completion(
        conversation, Role.ASSISTANT
    )
    return token_ids


def get_stop_tokens_for_assistant_actions() -> list[int]:
    return get_encoding().stop_tokens_for_assistant_actions()


def get_streamable_parser_for_assistant() -> StreamableParser:
    return StreamableParser(get_encoding(), role=Role.ASSISTANT)


def parse_output_message(message: Message):
    if message.author.role != "assistant":
        # This is a message from a tool to the assistant (e.g., search result).
        # Don't include it in the final output for now. This aligns with
        # OpenAI's behavior on models like o4-mini.
        return []

    output_items = []
    recipient = message.recipient
    if recipient is not None and recipient.startswith("browser."):
        if len(message.content) != 1:
            raise ValueError("Invalid number of contents in browser message")
        content = message.content[0]
        browser_call = json.loads(content.text)
        # TODO: translate to url properly!
        if recipient == "browser.search":
            action = ActionSearch(
                query=f"cursor:{browser_call.get('query', '')}", type="search"
            )
        elif recipient == "browser.open":
            action = ActionOpenPage(
                url=f"cursor:{browser_call.get('url', '')}", type="open_page"
            )
        elif recipient == "browser.find":
            action = ActionFind(
                pattern=browser_call["pattern"],
                url=f"cursor:{browser_call.get('url', '')}",
                type="find",
            )
        else:
            raise ValueError(f"Unknown browser action: {recipient}")
        web_search_item = ResponseFunctionWebSearch(
            id=f"ws_{random_uuid()}",
            action=action,
            status="completed",
            type="web_search_call",
        )
        output_items.append(web_search_item)
    elif message.channel == "analysis":
        for content in message.content:
            reasoning_item = ResponseReasoningItem(
                id=f"rs_{random_uuid()}",
                type="reasoning",
                summary=[],
                content=[
                    ResponseReasoningTextContent(
                        text=content.text, type="reasoning_text"
                    )
                ],
                status=None,
            )
            output_items.append(reasoning_item)
    elif message.channel == "commentary":
        if message.recipient.startswith("functions."):
            function_name = message.recipient.split(".")[-1]
            for content in message.content:
                random_id = random_uuid()
                response_item = ResponseFunctionToolCall(
                    arguments=content.text,
                    call_id=f"call_{random_id}",
                    type="function_call",
                    name=function_name,
                    id=f"ft_{random_id}",
                )
                output_items.append(response_item)
        elif message.recipient.startswith("python") or message.recipient.startswith(
            "browser"
        ):
            for content in message.content:
                reasoning_item = ResponseReasoningItem(
                    id=f"rs_{random_uuid()}",
                    type="reasoning",
                    summary=[],
                    content=[
                        ResponseReasoningTextContent(
                            text=content.text, type="reasoning_text"
                        )
                    ],
                    status=None,
                )
                output_items.append(reasoning_item)
        else:
            raise ValueError(f"Unknown recipient: {message.recipient}")
    elif message.channel == "final":
        contents = []
        for content in message.content:
            output_text = ResponseOutputText(
                text=content.text,
                annotations=[],  # TODO
                type="output_text",
                logprobs=None,  # TODO
            )
            contents.append(output_text)
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=contents,
            role=message.author.role,
            status="completed",
            type="message",
        )
        output_items.append(text_item)
    else:
        raise ValueError(f"Unknown channel: {message.channel}")
    return output_items


def parse_remaining_state(parser: StreamableParser):
    if not parser.current_content:
        return []
    if parser.current_role != Role.ASSISTANT:
        return []
    current_recipient = parser.current_recipient
    if current_recipient is not None and current_recipient.startswith("browser."):
        return []

    if parser.current_channel == "analysis":
        reasoning_item = ResponseReasoningItem(
            id=f"rs_{random_uuid()}",
            type="reasoning",
            summary=[],
            content=[
                ResponseReasoningTextContent(
                    text=parser.current_content, type="reasoning_text"
                )
            ],
            status=None,
        )
        return [reasoning_item]
    elif parser.current_channel == "final":
        output_text = ResponseOutputText(
            content=[
                ResponseReasoningTextContent(
                    text=parser.current_content, type="reasoning_text"
                )
            ],
            annotations=[],  # TODO
            type="output_text",
            logprobs=None,  # TODO
        )
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=[output_text],
            role="assistant",
            status="completed",
            type="message",
        )
        return [text_item]
    return []


def parse_output_into_messages(token_ids: Iterable[int]):
    parser = get_streamable_parser_for_assistant()
    for token_id in token_ids:
        parser.process(token_id)
    return parser
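A minimal sketch of composing a Harmony prompt with the helpers above. It uses only functions defined in this file, requires the `openai-harmony` package to be installed, and the message contents are illustrative:

```python
from sglang.srt.entrypoints.harmony_utils import (
    get_developer_message,
    get_system_message,
    get_user_message,
    render_for_completion,
)

# Build system/developer/user messages, then render them into the prompt
# token ids that precede the assistant's next turn.
sys_msg = get_system_message(reasoning_effort="low")
dev_msg = get_developer_message(instructions="Answer concisely.")
user_msg = get_user_message("What is 2 + 2?")
token_ids = render_for_completion([sys_msg, dev_msg, user_msg])
print(len(token_ids))  # number of prompt tokens fed to the model
```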