sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -1,81 +0,0 @@
|
|
1
|
-
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
|
2
|
-
from sglang.srt.distributed.utils import divide
|
3
|
-
|
4
|
-
|
5
|
-
class MambaStateShapeCalculator:
    """Shape helpers for the recurrent state caches of linear-attention/Mamba layers.

    Every method is a classmethod that returns plain tuples of ints describing
    the per-tensor-parallel-rank shape of each state tensor; no instance state
    is involved.
    """

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """Return a one-element tuple holding the linear-attention state shape.

        The shape is ``(num_heads // tp_size, head_dim, head_dim)``. Heads are
        split with floor division here, unlike the other methods which use
        ``divide``.
        """
        local_heads = num_heads // tp_size
        return ((local_heads, head_dim, head_dim),)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for a Mamba-1 layer.

        * conv_state_shape     = ``(conv_kernel - 1, divide(intermediate_size, tp))``
        * temporal_state_shape = ``(divide(intermediate_size, tp), state_size)``
        """
        local_inner = divide(intermediate_size, tp_world_size)
        # The conv state is laid out with the kernel axis first.
        conv_shape = (conv_kernel - 1, local_inner)
        ssm_shape = (local_inner, state_size)
        return conv_shape, ssm_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for a Mamba-2 layer.

        ``n_groups`` is first padded via ``extra_groups_for_head_shards`` so the
        groups can be sharded alongside the heads. The conv channel count is
        ``intermediate_size + 2 * padded_groups * state_size``, split across
        ranks; the temporal state splits only the head axis.
        """
        # When groups do not divide evenly across ranks, extra (replicated)
        # groups are added so every head shard carries the groups it needs.
        padded_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size
        )
        full_conv_dim = intermediate_size + 2 * padded_groups * state_size

        # Contiguous along the channel ('dim') axis.
        conv_shape = (conv_kernel - 1, divide(full_conv_dim, tp_world_size))

        # Heads are split across ranks; head_dim and state_size are not.
        ssm_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return conv_shape, ssm_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        """Return a one-element tuple with the short-conv state shape.

        The shape is ``(conv_kernel - 1, divide(intermediate_size, tp))``.
        """
        local_dim = divide(intermediate_size, tp_world_size)
        return ((conv_kernel - 1, local_dim),)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Return how many groups to add so groups can follow the head shards.

        Returns 0 when ``ngroups`` is a multiple of ``tp_size``; otherwise
        returns ``tp_size - ngroups``, which pads ``ngroups == 1`` up to
        ``tp_size``.
        """
        # NOTE(review): for ngroups > tp_size that is not a multiple of it,
        # this is negative (e.g. ngroups=3, tp_size=2 -> -1); confirm callers
        # never pass that combination.
        return 0 if ngroups % tp_size == 0 else tp_size - ngroups
|
@@ -1,311 +0,0 @@
|
|
1
|
-
# Copyright 2023-2024 SGLang Team
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
# ==============================================================================
|
14
|
-
"""A tensor parallel worker."""
|
15
|
-
from __future__ import annotations
|
16
|
-
|
17
|
-
import dataclasses
|
18
|
-
import logging
|
19
|
-
import signal
|
20
|
-
import threading
|
21
|
-
from queue import Queue
|
22
|
-
from typing import TYPE_CHECKING, List, Optional, Tuple
|
23
|
-
|
24
|
-
import psutil
|
25
|
-
import torch
|
26
|
-
|
27
|
-
from sglang.srt.managers.io_struct import (
|
28
|
-
DestroyWeightsUpdateGroupReqInput,
|
29
|
-
GetWeightsByNameReqInput,
|
30
|
-
InitWeightsSendGroupForRemoteInstanceReqInput,
|
31
|
-
InitWeightsUpdateGroupReqInput,
|
32
|
-
LoadLoRAAdapterReqInput,
|
33
|
-
SendWeightsToRemoteInstanceReqInput,
|
34
|
-
UnloadLoRAAdapterReqInput,
|
35
|
-
UpdateWeightFromDiskReqInput,
|
36
|
-
UpdateWeightsFromDistributedReqInput,
|
37
|
-
UpdateWeightsFromTensorReqInput,
|
38
|
-
)
|
39
|
-
from sglang.srt.managers.overlap_utils import FutureMap
|
40
|
-
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
41
|
-
from sglang.srt.managers.tp_worker import TpModelWorker
|
42
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
|
43
|
-
from sglang.srt.server_args import ServerArgs
|
44
|
-
from sglang.srt.utils import DynamicGradMode
|
45
|
-
from sglang.utils import get_exception_traceback
|
46
|
-
|
47
|
-
if TYPE_CHECKING:
|
48
|
-
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
49
|
-
|
50
|
-
# Module-level logger; the forward thread reports fatal exceptions through it.
logger = logging.getLogger(__name__)
|
51
|
-
|
52
|
-
|
53
|
-
class TpModelWorkerClient:
    """A tensor parallel model worker whose forward passes run on a background thread.

    The scheduler calls ``forward_batch_generation`` to enqueue a batch (it
    returns placeholder "future" token ids immediately). Later, it calls
    ``resolve_last_batch_result`` to collect the CPU-side results produced by
    the forward thread. Most other methods delegate directly to the wrapped
    ``TpModelWorker``.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
    ):
        # Load the model
        self.worker = TpModelWorker(
            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
        )
        self.max_running_requests = self.worker.max_running_requests
        self.device = self.worker.device
        self.gpu_id = gpu_id

        # Init future mappings
        self.future_map = FutureMap(self.max_running_requests, self.device)

        # Set every attribute the forward thread may touch BEFORE starting it.
        # The thread's error handler uses self.parent_process. Assigning it
        # after start() left a window where an early failure raised
        # AttributeError instead of signalling the parent process.
        self.parent_process = psutil.Process().parent()
        # Captured on the constructing thread (not the forward thread).
        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

        self.hicache_layer_transfer_counter = None

        # Launch threads
        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
        self.output_queue = Queue()
        self.forward_stream = torch.get_device_module(self.device).Stream()
        self.forward_thread = threading.Thread(
            target=self.forward_thread_func,
        )
        self.forward_thread.start()

    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
        """Store the HiCache layer transfer counter on this client."""
        self.hicache_layer_transfer_counter = counter

    def get_worker_info(self):
        return self.worker.get_worker_info()

    def get_tokens_per_layer_info(self):
        return self.worker.get_tokens_per_layer_info()

    @property
    def sliding_window_size(self) -> Optional[int]:
        return self.worker.sliding_window_size

    @property
    def is_hybrid(self) -> bool:
        return self.worker.is_hybrid

    def get_pad_input_ids_func(self):
        return self.worker.get_pad_input_ids_func()

    def get_tp_group(self):
        return self.worker.get_tp_group()

    def get_attention_tp_group(self):
        return self.worker.get_attention_tp_group()

    def get_attention_tp_cpu_group(self):
        return self.worker.get_attention_tp_cpu_group()

    def get_memory_pool(self):
        """Return ``(req_to_token_pool, token_to_kv_pool_allocator)`` of the model runner."""
        return (
            self.worker.model_runner.req_to_token_pool,
            self.worker.model_runner.token_to_kv_pool_allocator,
        )

    def get_kv_cache(self):
        return self.worker.model_runner.token_to_kv_pool

    def forward_thread_func(self):
        """Thread entry point: run the forward loop on ``self.forward_stream``.

        On any exception, log the traceback and send SIGQUIT to the parent
        process.
        """
        try:
            with torch.get_device_module(self.device).stream(self.forward_stream):
                self.forward_thread_func_()
        except Exception:
            traceback = get_exception_traceback()
            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
            self.parent_process.send_signal(signal.SIGQUIT)

    @DynamicGradMode()
    def forward_thread_func_(self):
        """Consume batches from ``input_queue`` until the ``None`` sentinel arrives.

        For each batch: wait for the scheduler's sync event, resolve future
        token ids, run the forward pass, and store the new token ids in the
        future map. Non-blocking copies of the results to the CPU are then
        started, and ``(copy_done, logits_output, next_token_ids,
        can_run_cuda_graph)`` is pushed to ``output_queue``.
        """
        batch_pt = 0
        batch_lists: List = [None] * 2

        while True:
            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
            # Compare against the sentinel explicitly rather than relying on
            # the truthiness of a batch object.
            if model_worker_batch is None:
                break

            sync_event.wait()

            # Keep a reference of model_worker_batch by storing it into a list.
            # Otherwise, the tensor members of model_worker_batch will be released
            # by pytorch and cause CUDA illegal memory access errors.
            batch_lists[batch_pt % 2] = model_worker_batch
            batch_pt += 1

            # Create event
            copy_done = torch.get_device_module(self.device).Event()

            # Resolve future tokens in the input
            self.future_map.resolve_future(model_worker_batch)

            # Run forward
            forward_batch_output = self.worker.forward_batch_generation(
                model_worker_batch,
                model_worker_batch.launch_done,
            )

            logits_output, next_token_ids, can_run_cuda_graph = (
                forward_batch_output.logits_output,
                forward_batch_output.next_token_ids,
                forward_batch_output.can_run_cuda_graph,
            )

            # Update the future token ids map
            bs = len(model_worker_batch.seq_lens)
            if model_worker_batch.is_prefill_only:
                # For prefill-only requests, create dummy token IDs on CPU
                next_token_ids = torch.zeros(bs, dtype=torch.long)

            # store the future indices into future map
            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)

            # Copy results to the CPU
            if model_worker_batch.return_logprob:
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
                    )
                if logits_output.input_token_logprobs is not None:
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                    )
            if logits_output.hidden_states is not None:
                logits_output.hidden_states = logits_output.hidden_states.to(
                    "cpu", non_blocking=True
                )
            # Only copy to CPU if not already on CPU
            if next_token_ids.device.type != "cpu":
                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
            # Recorded after the non-blocking copies so that the consumer can
            # synchronize on it before reading the CPU tensors.
            copy_done.record()

            self.output_queue.put(
                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
            )

    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
        """
        This function is called to resolve the last batch result and
        wait for the current batch to be launched. Used in overlap mode.

        Blocks on ``output_queue``, optionally waits on ``launch_done``, then
        synchronizes on the batch's ``copy_done`` event. Logprob tensors and
        ``next_token_ids`` are converted to Python lists/tuples.

        Returns:
            ``(logits_output, next_token_ids, can_run_cuda_graph)``.
        """
        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
            self.output_queue.get()
        )

        if launch_done is not None:
            launch_done.wait()
        copy_done.synchronize()

        if logits_output.next_token_logprobs is not None:
            logits_output.next_token_logprobs = (
                logits_output.next_token_logprobs.tolist()
            )
        if logits_output.input_token_logprobs is not None:
            logits_output.input_token_logprobs = tuple(
                logits_output.input_token_logprobs.tolist()
            )
        next_token_ids = next_token_ids.tolist()
        return logits_output, next_token_ids, can_run_cuda_graph

    def forward_batch_generation(
        self, model_worker_batch: ModelWorkerBatch
    ) -> ForwardBatchOutput:
        """Enqueue ``model_worker_batch`` for the forward thread and return immediately.

        The returned ``ForwardBatchOutput`` carries future token ids (from
        ``future_map.update_next_future``) rather than real ones, and
        ``can_run_cuda_graph=False``.
        """
        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
        sampling_info = model_worker_batch.sampling_info
        sampling_info.update_penalties()
        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
            sampling_info,
            sampling_info_done=threading.Event(),
            penalizer_orchestrator=None,
        )

        # A cuda stream sync here to avoid the cuda illegal memory access error.
        sync_event = torch.get_device_module(self.device).Event()
        sync_event.record(self.scheduler_stream)

        # Push a new batch to the queue
        bs = len(model_worker_batch.seq_lens)
        cur_future_map_ct = self.future_map.update_ct(bs)
        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))

        # get this forward batch's future token ids
        future_next_token_ids = self.future_map.update_next_future(
            cur_future_map_ct, bs
        )
        return ForwardBatchOutput(
            next_token_ids=future_next_token_ids,
            can_run_cuda_graph=False,
        )

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        success, message = self.worker.update_weights_from_disk(recv_req)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        success, message = self.worker.destroy_weights_update_group(recv_req)
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        success, message = self.worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return success, message

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        success, message = self.worker.send_weights_to_remote_instance(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        success, message = self.worker.update_weights_from_distributed(recv_req)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        success, message = self.worker.update_weights_from_tensor(recv_req)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)

    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
        return self.worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
        return self.worker.unload_lora_adapter(recv_req)

    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
        return self.worker.can_run_lora_batch(lora_ids)

    def shutdown(self):
        """Ask the forward thread to stop once it has processed already-queued batches.

        Enqueues the 3-tuple sentinel ``(None, None, None)``, which matches the
        unpacking in ``forward_thread_func_``. Does not join the thread.
        """
        self.input_queue.put((None, None, None))

    def __delete__(self):
        # NOTE: ``__delete__`` is the descriptor-protocol hook, not a finalizer;
        # Python never calls it when this object is garbage-collected. It is
        # kept only for callers that invoke it explicitly.
        # The previous body enqueued a 2-tuple (the forward thread unpacks 3
        # values, so it crashed and SIGQUIT the parent) and referenced a
        # nonexistent ``copy_queue``.
        self.shutdown()
|