sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED

@@ -11,6 +11,7 @@ import triton
 import triton.language as tl

 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -124,18 +125,34 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)

+        self._apply_rotary_emb_wrapped = _apply_rotary_emb
+
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+            self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
+                self._apply_rotary_emb_wrapped
+            )
+
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
+        init_device = (
+            "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
+        )
         inv_freq = 1.0 / (
             base
             ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+                torch.arange(
+                    0, self.rotary_dim, 2, dtype=torch.float, device=init_device
+                )
+                / self.rotary_dim
             )
         )
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            inv_freq = inv_freq.cuda()
         return inv_freq

     def _compute_cos_sin_cache(self) -> torch.Tensor:
@@ -173,14 +190,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., : self.rotary_dim]
         query_pass = query[..., self.rotary_dim :]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = self._apply_rotary_emb_wrapped(
+            query_rot, cos, sin, self.is_neox_style
+        )
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., : self.rotary_dim]
         key_pass = key[..., self.rotary_dim :]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = self._apply_rotary_emb_wrapped(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

@@ -300,10 +319,20 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        [three lines of the previous xpu implementation; not recoverable from this rendering]
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for xpu implementation"
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        return torch.ops.sgl_kernel.rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )


 class LinearScalingRotaryEmbedding(RotaryEmbedding):
@@ -1058,6 +1087,7 @@ def _triton_mrope_forward(
     mrope_section_h: tl.constexpr,
     mrope_section_w: tl.constexpr,
     is_interleaved: tl.constexpr,
+    is_neox_style: tl.constexpr,
 ):
     # Adapted from
     # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
@@ -1112,51 +1142,99 @@
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
+    if is_neox_style:
+        first_half_q_offsets = (
+            tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_half_k_offsets = (
+            tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )
+        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )

-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (rd // 2)
-    second_half_k_offsets = first_half_k_offsets + (rd // 2)
-    second_q_mask = first_q_mask
-    second_k_mask = first_k_mask
+        # right half of the head
+        second_half_q_offsets = first_half_q_offsets + (rd // 2)
+        second_half_k_offsets = first_half_k_offsets + (rd // 2)
+        second_q_mask = first_q_mask
+        second_k_mask = first_k_mask
+
+        q_tile_2 = tl.load(
+            q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
+        ).to(sin_row.dtype)
+        k_tile_2 = tl.load(
+            k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
+        ).to(sin_row.dtype)
+
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        # Since cos and sin are now half-size,
+        # we use the same cos_row and sin_row for both halves
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        base_q = tl.arange(0, pad_n_qh)[:, None] * hd
+        base_k = tl.arange(0, pad_n_kh)[:, None] * hd
+        even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
+        odd_idx = even_idx + 1
+
+        even_q_offsets = base_q + even_idx
+        odd_q_offsets = base_q + odd_idx
+        even_k_offsets = base_k + even_idx
+        odd_k_offsets = base_k + odd_idx
+
+        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
+        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
+        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
+
+        even_q_mask = qn_mask & idx_mask
+        odd_q_mask = qn_mask & idx_mask
+        even_k_mask = kn_mask & idx_mask
+        odd_k_mask = kn_mask & idx_mask
+
+        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
-    # Since cos and sin are now half-size,
-    # we use the same cos_row and sin_row for both halves
-    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
-    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
+        # Interleaved (non-NeoX) rotary embedding:
+        # Each (even, odd) channel pair forms one rotation arm.
+        # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)

-    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
-    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)


 def triton_mrope(
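For reference, the two head layouts the kernel now distinguishes can be written in a few lines of plain PyTorch. This is an illustrative sketch, not the Triton kernel above: both layouts apply the same per-pair rotation (a, b) -> (a*cos - b*sin, b*cos + a*sin) and differ only in which channels form a pair.

    import torch

    def rope_neox(x, cos, sin):
        # NeoX style: pair channel i with channel i + d//2 (rotate-halves layout).
        d = x.shape[-1]
        x1, x2 = x[..., : d // 2], x[..., d // 2 :]
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    def rope_interleaved(x, cos, sin):
        # Non-NeoX style: pair channel 2i with channel 2i+1 (even/odd arms).
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x2 * cos + x1 * sin
        return out

    x = torch.randn(2, 8)
    angles = torch.rand(2, 4)           # one angle per rotation arm (d//2 of them)
    cos, sin = angles.cos(), angles.sin()
    print(rope_neox(x, cos, sin).shape, rope_interleaved(x, cos, sin).shape)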
@@ -1168,6 +1246,7 @@ def triton_mrope(
     head_size: int,
     rotary_dim: int,
     mrope_interleaved: bool,
+    is_neox_style: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """The mrope triton kernel.

@@ -1218,6 +1297,7 @@ def triton_mrope(
         mrope_section[1],
         mrope_section[2],
         mrope_interleaved,
+        is_neox_style,
     )
     return q, k

@@ -1361,6 +1441,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         else:
             return self._forward_native(positions, query, key)

+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def _forward_triton(
         self,
         positions: torch.Tensor,
@@ -1379,6 +1460,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         if positions.ndim == 2:
             assert self.mrope_section

+        torch._dynamo.graph_break()
         q, k = triton_mrope(
             query,
             key,
@@ -1388,7 +1470,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             self.head_size,
             self.rotary_dim,
             self.mrope_interleaved,
+            self.is_neox_style,
         )
+        torch._dynamo.graph_break()

         return q.reshape(query_shape), k.reshape(key_shape)

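When `rl_on_policy_target == "fsdp"`, the rotary changes above force the native (non-fused) forward path, build `inv_freq` on CPU before moving it to CUDA, and route `_apply_rotary_emb` through `torch.compile(dynamic=True)`; this appears aimed at keeping inference numerics reproducible against an FSDP trainer. A minimal sketch of the wrapper pattern, with illustrative names rather than sglang's:

    import torch

    def apply_rotary_emb(x, cos, sin):
        # NeoX-style rotation of the two halves of the head dimension.
        x1, x2 = torch.chunk(x, 2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * torch.cat((cos, cos), dim=-1) + rotated * torch.cat((sin, sin), dim=-1)

    class Rotary:
        def __init__(self, deterministic: bool = False):
            # The compiled function is swapped in behind a stable attribute,
            # so call sites never change.
            self._apply = apply_rotary_emb
            if deterministic:
                # dynamic=True avoids per-shape recompilation across batch sizes.
                self._apply = torch.compile(dynamic=True)(self._apply)

        def __call__(self, x, cos, sin):
            return self._apply(x, cos, sin)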
sglang/srt/layers/sampler.py
CHANGED
@@ -102,6 +102,14 @@ class Sampler(nn.Module):
             if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)

+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logits_div_temperature = (
+                    logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+                )
+                logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                    logits_div_temperature, dim=-1
+                )
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -148,8 +156,11 @@
             )

             if return_logprob:
+                if get_global_server_args().rl_on_policy_target == "fsdp":
+                    logprobs = logprobs_via_logsoftmax_kernel
+                    del logprobs_via_logsoftmax_kernel
                 # clamp to avoid -inf
-                if SGLANG_RETURN_ORIGINAL_LOGPROB:
+                elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                     logprobs = torch.log(probs_without_temp_scaling).clamp(
                         min=torch.finfo(probs_without_temp_scaling.dtype).min
                     )
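The new fsdp branch computes log-probabilities with a fused `torch.log_softmax` over the temperature-divided logits instead of `torch.log(torch.softmax(...))`. The fused form evaluates `logits - logsumexp(logits)` and never materializes probabilities, so entries that would underflow to zero (and then log to -inf, which the existing path has to clamp) stay finite. A small standalone demonstration, unrelated to sglang internals:

    import torch

    logits = torch.tensor([0.0, -200.0])

    # Two-step form: exp(-200) underflows to 0 in float32, so log() yields -inf.
    two_step = torch.log(torch.softmax(logits, dim=-1))

    # Fused form: computed as logits - logsumexp(logits), stays finite.
    fused = torch.log_softmax(logits, dim=-1)

    print(two_step)  # tensor([0., -inf])
    print(fused)     # tensor([   0., -200.])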
sglang/srt/lora/lora_registry.py
CHANGED
@@ -205,3 +205,12 @@ class LoRARegistry:
         Returns the total number of LoRA adapters currently registered.
         """
         return len(self._registry)
+
+    def get_all_adapters(self) -> Dict[str, LoRARef]:
+        """
+        Returns a dictionary of all registered LoRA adapters.
+
+        Returns:
+            Dict[str, LoRARef]: A dictionary mapping LoRA names to LoRARef objects.
+        """
+        return dict(self._registry)
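`get_all_adapters` returns `dict(self._registry)`, a shallow snapshot rather than a live view, so callers can iterate while other coroutines keep registering adapters. A stand-in demonstration of that design choice (plain dicts, not the registry itself):

    registry = {"adapter_a": "ref_a"}   # stands in for LoRARegistry._registry
    snapshot = dict(registry)           # what get_all_adapters() returns
    registry["adapter_b"] = "ref_b"     # a registration that happens later
    print(list(snapshot))               # ['adapter_a'] - the snapshot is unaffected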
sglang/srt/managers/async_mm_data_processor.py
ADDED

@@ -0,0 +1,122 @@
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncMMDataProcessor:
+    """
+    Async wrapper for a multimodal processor.
+
+    Behavior:
+    - If the underlying processor exposes `process_mm_data_async`, call/await it directly.
+    - Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool.
+    - Optionally guard per-call concurrency via an asyncio.Semaphore.
+    - Optionally enforce per-call timeout via asyncio.wait_for.
+    """
+
+    def __init__(
+        self,
+        mm_processor: Any,
+        *,
+        max_concurrent_calls: Optional[int] = None,
+        timeout_s: Optional[float] = None,
+    ) -> None:
+        """
+        Args:
+            mm_processor: An object exposing either
+                - async def process_mm_data_async(...): -> Dict[str, Any]
+                or
+                - def process_mm_data(...): -> Dict[str, Any]
+            max_concurrent_calls: Optional concurrency cap for per-call execution.
+            timeout_s: Optional timeout (seconds) for each `process()` call.
+        """
+        self.mm_processor = mm_processor
+        self.timeout_s = timeout_s
+
+        # Concurrency guard (None -> unlimited)
+        self.semaphore = (
+            asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None
+        )
+
+        # Detect async path; if missing, prepare a fallback executor for sync path
+        self._proc_async = getattr(mm_processor, "process_mm_data_async", None)
+        self.is_async = asyncio.iscoroutinefunction(self._proc_async)
+        self.fallback_exec: Optional[ThreadPoolExecutor] = (
+            ThreadPoolExecutor(max_workers=max_concurrent_calls)
+            if not self.is_async
+            else None
+        )
+
+    async def process(
+        self,
+        *,
+        image_data: Optional[List[Union[str, bytes]]] = None,
+        audio_data: Optional[List[Union[str, bytes]]] = None,
+        input_text_or_ids: Union[str, List[int], None] = None,
+        request_obj: Any,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """
+        Public entrypoint: process a single multimodal request without blocking the event loop.
+        """
+
+        async def _invoke() -> Dict[str, Any]:
+            if self.is_async:
+                # Native async implementation
+                return await self._proc_async(
+                    image_data=image_data,
+                    audio_data=audio_data,
+                    input_text=input_text_or_ids,
+                    request_obj=request_obj,
+                    **kwargs,
+                )
+
+            # Synchronous fallback
+            sync_fn = getattr(self.mm_processor, "process_mm_data", None)
+            if not callable(sync_fn):
+                raise RuntimeError(
+                    "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'."
+                )
+            loop = asyncio.get_running_loop()
+            fn = partial(
+                sync_fn,
+                image_data=image_data,
+                audio_data=audio_data,
+                input_text=input_text_or_ids,
+                request_obj=request_obj,
+                **kwargs,
+            )
+            return await loop.run_in_executor(self.fallback_exec, fn)
+
+        # Apply optional concurrency guard
+        if self.semaphore is not None:
+            async with self.semaphore:
+                if self.timeout_s is not None:
+                    return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
+                return await _invoke()
+
+        # No concurrency guard
+        if self.timeout_s is not None:
+            return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
+        return await _invoke()
+
+    def shutdown(self) -> None:
+        """Gracefully shutdown resources owned by this wrapper."""
+        try:
+            if self.fallback_exec:
+                self.fallback_exec.shutdown(wait=False)
+        except Exception:
+            logger.exception(
+                "Error while shutting down fallback executor in AsyncMMDataProcessor"
+            )
+
+    def __del__(self):
+        # Best-effort shutdown
+        try:
+            self.shutdown()
+        except Exception:
+            pass
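A hypothetical usage sketch of the new wrapper. The processor class and its return value are invented for illustration; only `AsyncMMDataProcessor` itself comes from the file above. With a plain synchronous processor, each call runs in the fallback thread pool under the semaphore and timeout:

    import asyncio

    from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor

    class MyProcessor:  # hypothetical sync-only processor
        def process_mm_data(self, *, image_data, audio_data, input_text, request_obj, **kw):
            return {"num_images": len(image_data or [])}

    async def main():
        proc = AsyncMMDataProcessor(MyProcessor(), max_concurrent_calls=4, timeout_s=30.0)
        out = await proc.process(
            image_data=[b"<bytes>"], input_text_or_ids="describe", request_obj=None
        )
        print(out)  # {'num_images': 1}
        proc.shutdown()

    asyncio.run(main())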
sglang/srt/managers/data_parallel_controller.py
CHANGED

@@ -34,13 +34,21 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     WatchLoadUpdateReq,
 )
-from sglang.srt.managers.schedule_batch import Req
+from sglang.srt.managers.schedule_batch import Req, RequestStage
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import (
     DP_ATTENTION_HANDSHAKE_PORT_DELTA,
     PortArgs,
     ServerArgs,
 )
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_get_proc_propagate_context,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     bind_port,
     configure_logger,
@@ -170,11 +178,22 @@ class DataParallelController:
     def handle_load_update_req(self, obj):
         self.dp_budget.update_budget(obj)

+    def dispatching_with_trace(self, req: Req):
+        if self.server_args.enable_trace:
+            trace_set_proc_propagate_context(req.rid, req.trace_context)
+            trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
+            req.trace_context = trace_get_proc_propagate_context(req.rid)
+
+        self.dispatching(req)
+
+        if self.server_args.enable_trace:
+            trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)
+
     def init_dispatcher(self):
         self._request_dispatcher = TypeBasedDispatcher(
             [
-                (TokenizedGenerateReqInput, self.dispatching),
-                (TokenizedEmbeddingReqInput, self.dispatching),
+                (TokenizedGenerateReqInput, self.dispatching_with_trace),
+                (TokenizedEmbeddingReqInput, self.dispatching_with_trace),
                 (BlockReqInput, self.send_to_all_workers),
                 (WatchLoadUpdateReq, self.handle_load_update_req),
             ]
@@ -487,6 +506,14 @@ def run_data_parallel_controller_process(
     pipe_writer,
 ):
     kill_itself_when_parent_died()
+    if server_args.enable_trace:
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        thread_label = "DP Controller"
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill DP Controller"
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode DP Controller"
+        trace_set_thread_info(thread_label)
     setproctitle.setproctitle("sglang::data_parallel_controller")
     faulthandler.enable()
     configure_logger(server_args)
sglang/srt/managers/detokenizer_manager.py
CHANGED

@@ -235,6 +235,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                     new_text = ""
                 else:
                     new_text = find_printable_text(new_text)
+            else:
+                del self.decode_status[recv_obj.rids[i]]

             output_str = self.trim_matched_stop(
                 s.decoded_text + new_text,
@@ -273,6 +275,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 output_hidden_states=recv_obj.output_hidden_states,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=recv_obj.retraction_counts,
                 token_steps=recv_obj.token_steps,
             )

sglang/srt/managers/io_struct.py
CHANGED
@@ -574,6 +574,7 @@ class GenerateReqInput(BaseReq):
             custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
             return_entropy=self.return_entropy,
+            http_worker_ipc=self.http_worker_ipc,
         )


@@ -694,6 +695,9 @@ class EmbeddingReqInput(BaseReq):
     # tracing context
     trace_context: Optional[Dict] = None

+    # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
+    dimensions: Optional[int] = None
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -759,6 +763,7 @@
                 sampling_params=self.sampling_params[i],
                 rid=self.rid[i],
                 is_cross_encoder_request=True,
+                http_worker_ipc=self.http_worker_ipc,
             )

         return EmbeddingReqInput(
@@ -769,6 +774,8 @@
             video_data=self.video_data[i] if self.video_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
+            dimensions=self.dimensions,
+            http_worker_ipc=self.http_worker_ipc,
         )


@@ -788,6 +795,8 @@ class TokenizedEmbeddingReqInput(BaseReq):
     data_parallel_rank: Optional[int] = None
     # Priority for the request
     priority: Optional[int] = None
+    # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
+    dimensions: Optional[int] = None


 @dataclass
@@ -851,6 +860,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None

@@ -927,6 +939,9 @@ class BatchStrOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None

@@ -969,6 +984,9 @@ class BatchEmbeddingOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+

 @dataclass
 class ClearHiCacheReqInput(BaseReq):
@@ -1212,7 +1230,7 @@ class AbortReq(BaseReq):
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
-
+    abort_message: Optional[str] = None

     def __post_init__(self):
         # FIXME: This is a hack to keep the same with the old code
@@ -1455,6 +1473,16 @@ class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]


+@dataclass
+class SetInjectDumpMetadataReqInput(BaseReq):
+    dump_metadata: Dict[str, Any]
+
+
+@dataclass
+class SetInjectDumpMetadataReqOutput(BaseReq):
+    success: bool
+
+
 @dataclass
 class LazyDumpTensorsReqInput(BaseReq):
     pass
@@ -1486,6 +1514,3 @@
         raise ValueError(
             f"{name} is a subclass of BaseReq but not follow the naming convention."
         )
-
-
-_check_all_req_types()
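The new `dimensions` field threads a Matryoshka embedding size from `EmbeddingReqInput` through to `TokenizedEmbeddingReqInput`. How sglang applies it is not shown in this diff; the usual Matryoshka treatment is to truncate the embedding to its first `dimensions` components and re-normalize, roughly:

    import torch

    def truncate_matryoshka(embedding: torch.Tensor, dimensions: int) -> torch.Tensor:
        # Keep the leading components, then re-normalize so cosine similarity
        # remains meaningful at the reduced size. Illustrative only.
        truncated = embedding[..., :dimensions]
        return truncated / truncated.norm(dim=-1, keepdim=True)

    vec = torch.randn(1024)
    print(truncate_matryoshka(vec, 256).shape)  # torch.Size([256])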
sglang/srt/managers/multi_tokenizer_mixin.py
CHANGED

@@ -13,7 +13,12 @@ from __future__ import annotations
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+
+"""
+Mixin classes and utils for multi-http-worker mode.
+This mode uses multiple processes to handle requests and tokenization, reducing Python and HTTP-server overhead.
+"""
+
 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -329,6 +334,11 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            retraction_counts=(
+                [output.retraction_counts[i]]
+                if len(output.retraction_counts) > i
+                else None
+            ),
             token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
@@ -566,3 +576,14 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
         logger.warning(
             "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
         )
+
+
+class SenderWrapper:
+    def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
+        self.port_args = port_args
+        self.send_to_scheduler = send_to_scheduler
+
+    def send_pyobj(self, obj):
+        if isinstance(obj, BaseReq):
+            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
+        self.send_to_scheduler.send_pyobj(obj)