sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@ import psutil
|
|
8
8
|
import torch
|
9
9
|
|
10
10
|
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
11
|
-
from sglang.srt.utils import debug_timing
|
12
11
|
|
13
12
|
logger = logging.getLogger(__name__)
|
14
13
|
|
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
|
|
99
98
|
def init_kv_buffer(self):
|
100
99
|
raise NotImplementedError()
|
101
100
|
|
102
|
-
@abc.abstractmethod
|
103
|
-
def transfer(self, indices, flat_data):
|
104
|
-
raise NotImplementedError()
|
105
|
-
|
106
|
-
@abc.abstractmethod
|
107
|
-
def get_flat_data(self, indices):
|
108
|
-
raise NotImplementedError()
|
109
|
-
|
110
|
-
@abc.abstractmethod
|
111
|
-
def get_flat_data_by_layer(self, indices, layer_id):
|
112
|
-
raise NotImplementedError()
|
113
|
-
|
114
|
-
@abc.abstractmethod
|
115
|
-
def assign_flat_data(self, indices, flat_data):
|
116
|
-
raise NotImplementedError()
|
117
|
-
|
118
101
|
@synchronized()
|
119
102
|
def clear(self):
|
120
103
|
# Initialize memory states and tracking structures.
|
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
243
226
|
pin_memory=self.pin_memory,
|
244
227
|
)
|
245
228
|
|
246
|
-
@debug_timing
|
247
|
-
def transfer(self, indices, flat_data):
|
248
|
-
|
249
|
-
self.kv_buffer[:, :, indices] = flat_data.to(
|
250
|
-
device=self.device, non_blocking=False
|
251
|
-
)
|
229
|
+
@property
|
230
|
+
def k_buffer(self):
|
231
|
+
return self.kv_buffer[0]
|
252
232
|
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
def get_flat_data_by_layer(self, indices, layer_id):
|
257
|
-
return self.kv_buffer[:, layer_id - self.start_layer, indices]
|
258
|
-
|
259
|
-
def assign_flat_data(self, indices, flat_data):
|
260
|
-
self.kv_buffer[:, :, indices] = flat_data
|
261
|
-
|
262
|
-
def write_page_all_layers(self, host_indices, device_indices, device_pool):
|
263
|
-
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
264
|
-
for i in range(len(device_indices_cpu)):
|
265
|
-
h_index = host_indices[i * self.page_size]
|
266
|
-
d_index = device_indices_cpu[i]
|
267
|
-
for j in range(self.layer_num):
|
268
|
-
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
|
269
|
-
device_pool.k_buffer[j][d_index : d_index + self.page_size],
|
270
|
-
non_blocking=True,
|
271
|
-
)
|
272
|
-
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
|
273
|
-
device_pool.v_buffer[j][d_index : d_index + self.page_size],
|
274
|
-
non_blocking=True,
|
275
|
-
)
|
276
|
-
|
277
|
-
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
|
278
|
-
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
279
|
-
for i in range(len(device_indices_cpu)):
|
280
|
-
h_index = host_indices[i * self.page_size]
|
281
|
-
d_index = device_indices_cpu[i]
|
282
|
-
device_pool.k_buffer[layer_id - self.start_layer][
|
283
|
-
d_index : d_index + self.page_size
|
284
|
-
].copy_(
|
285
|
-
self.kv_buffer[
|
286
|
-
0, layer_id - self.start_layer, h_index : h_index + self.page_size
|
287
|
-
],
|
288
|
-
non_blocking=True,
|
289
|
-
)
|
290
|
-
device_pool.v_buffer[layer_id - self.start_layer][
|
291
|
-
d_index : d_index + self.page_size
|
292
|
-
].copy_(
|
293
|
-
self.kv_buffer[
|
294
|
-
1, layer_id - self.start_layer, h_index : h_index + self.page_size
|
295
|
-
],
|
296
|
-
non_blocking=True,
|
297
|
-
)
|
233
|
+
@property
|
234
|
+
def v_buffer(self):
|
235
|
+
return self.kv_buffer[1]
|
298
236
|
|
299
237
|
|
300
238
|
class MLATokenToKVPoolHost(HostKVCache):
|
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
337
275
|
device=self.device,
|
338
276
|
pin_memory=self.pin_memory,
|
339
277
|
)
|
340
|
-
|
341
|
-
@debug_timing
|
342
|
-
def transfer(self, indices, flat_data):
|
343
|
-
# backup prepared data from device to host
|
344
|
-
self.kv_buffer[:, indices] = flat_data.to(
|
345
|
-
device=self.device, non_blocking=False
|
346
|
-
)
|
347
|
-
|
348
|
-
def get_flat_data(self, indices):
|
349
|
-
return self.kv_buffer[:, indices]
|
350
|
-
|
351
|
-
def get_flat_data_by_layer(self, indices, layer_id):
|
352
|
-
return self.kv_buffer[layer_id - self.start_layer, indices]
|
353
|
-
|
354
|
-
def assign_flat_data(self, indices, flat_data):
|
355
|
-
self.kv_buffer[:, indices] = flat_data
|
356
|
-
|
357
|
-
def write_page_all_layers(self, host_indices, device_indices, device_pool):
|
358
|
-
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
359
|
-
for i in range(len(device_indices_cpu)):
|
360
|
-
h_index = host_indices[i * self.page_size]
|
361
|
-
d_index = device_indices_cpu[i]
|
362
|
-
for j in range(self.layer_num):
|
363
|
-
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
|
364
|
-
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
|
365
|
-
non_blocking=True,
|
366
|
-
)
|
367
|
-
|
368
|
-
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
|
369
|
-
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
370
|
-
for i in range(len(device_indices_cpu)):
|
371
|
-
h_index = host_indices[i * self.page_size]
|
372
|
-
d_index = device_indices_cpu[i]
|
373
|
-
device_pool.kv_buffer[layer_id - self.start_layer][
|
374
|
-
d_index : d_index + self.page_size
|
375
|
-
].copy_(
|
376
|
-
self.kv_buffer[
|
377
|
-
layer_id - self.start_layer, h_index : h_index + self.page_size
|
378
|
-
],
|
379
|
-
non_blocking=True,
|
380
|
-
)
|
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
|
|
196
196
|
|
197
197
|
if self.page_size != 1:
|
198
198
|
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
199
|
-
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
199
|
+
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
200
|
+
dtype=torch.int64, copy=True
|
201
|
+
)
|
200
202
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
201
203
|
else:
|
202
204
|
page_aligned_len = len(kv_indices)
|
203
|
-
page_aligned_kv_indices = kv_indices.clone()
|
205
|
+
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
204
206
|
|
205
207
|
# Radix Cache takes one ref in memory pool
|
206
208
|
new_prefix_len = self.insert(
|
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
|
|
226
228
|
|
227
229
|
if self.page_size != 1:
|
228
230
|
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
229
|
-
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
231
|
+
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
232
|
+
dtype=torch.int64, copy=True
|
233
|
+
)
|
230
234
|
else:
|
231
235
|
page_aligned_len = len(kv_indices)
|
232
|
-
page_aligned_kv_indices = kv_indices.clone()
|
236
|
+
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
233
237
|
page_aligned_token_ids = token_ids[:page_aligned_len]
|
234
238
|
|
235
239
|
# Radix Cache takes one ref in memory pool
|
@@ -168,7 +168,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|
168
168
|
capture_bs += [model_runner.req_to_token_pool.size]
|
169
169
|
|
170
170
|
if server_args.enable_two_batch_overlap:
|
171
|
-
capture_bs = [bs for bs in capture_bs if bs
|
171
|
+
capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
|
172
172
|
|
173
173
|
if server_args.cuda_graph_max_bs:
|
174
174
|
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
@@ -679,6 +679,7 @@ class CudaGraphRunner:
|
|
679
679
|
forward_mode=self.capture_forward_mode,
|
680
680
|
bs=bs,
|
681
681
|
num_token_non_padded=len(forward_batch.input_ids),
|
682
|
+
spec_info=forward_batch.spec_info,
|
682
683
|
)
|
683
684
|
if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
|
684
685
|
forward_batch.spec_info.custom_mask = self.custom_mask
|
@@ -39,7 +39,12 @@ import triton
|
|
39
39
|
import triton.language as tl
|
40
40
|
|
41
41
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
42
|
-
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
|
42
|
+
from sglang.srt.utils import (
|
43
|
+
flatten_nested_list,
|
44
|
+
get_compiler_backend,
|
45
|
+
is_npu,
|
46
|
+
support_triton,
|
47
|
+
)
|
43
48
|
|
44
49
|
if TYPE_CHECKING:
|
45
50
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
@@ -50,6 +55,8 @@ if TYPE_CHECKING:
|
|
50
55
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
51
56
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
52
57
|
|
58
|
+
_is_npu = is_npu()
|
59
|
+
|
53
60
|
|
54
61
|
class ForwardMode(IntEnum):
|
55
62
|
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
|
@@ -247,6 +254,7 @@ class ForwardBatch:
|
|
247
254
|
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
248
255
|
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
249
256
|
gathered_buffer: Optional[torch.Tensor] = None
|
257
|
+
is_extend_in_batch: bool = False
|
250
258
|
can_run_dp_cuda_graph: bool = False
|
251
259
|
global_forward_mode: Optional[ForwardMode] = None
|
252
260
|
|
@@ -292,6 +300,7 @@ class ForwardBatch:
|
|
292
300
|
return_logprob=batch.return_logprob,
|
293
301
|
top_logprobs_nums=batch.top_logprobs_nums,
|
294
302
|
token_ids_logprobs=batch.token_ids_logprobs,
|
303
|
+
is_extend_in_batch=batch.is_extend_in_batch,
|
295
304
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
296
305
|
global_forward_mode=batch.global_forward_mode,
|
297
306
|
lora_paths=batch.lora_paths,
|
@@ -352,7 +361,9 @@ class ForwardBatch:
|
|
352
361
|
|
353
362
|
if ret.forward_mode.is_idle():
|
354
363
|
ret.positions = torch.empty((0,), device=device)
|
355
|
-
TboForwardBatchPreparer.prepare(ret)
|
364
|
+
TboForwardBatchPreparer.prepare(
|
365
|
+
ret, is_draft_worker=model_runner.is_draft_worker
|
366
|
+
)
|
356
367
|
return ret
|
357
368
|
|
358
369
|
# Override the positions with spec_info
|
@@ -397,7 +408,9 @@ class ForwardBatch:
|
|
397
408
|
if model_runner.server_args.lora_paths is not None:
|
398
409
|
model_runner.lora_manager.prepare_lora_batch(ret)
|
399
410
|
|
400
|
-
TboForwardBatchPreparer.prepare(ret)
|
411
|
+
TboForwardBatchPreparer.prepare(
|
412
|
+
ret, is_draft_worker=model_runner.is_draft_worker
|
413
|
+
)
|
401
414
|
|
402
415
|
return ret
|
403
416
|
|
@@ -735,7 +748,7 @@ def compute_position_torch(
|
|
735
748
|
return positions.to(torch.int64), extend_start_loc
|
736
749
|
|
737
750
|
|
738
|
-
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
751
|
+
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
|
739
752
|
def clamp_position(seq_lens):
|
740
753
|
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
741
754
|
|