sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -34,10 +34,11 @@ import torch
 import torch.distributed as dist
 import triton
 import triton.language as tl
+from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-
-
-
-
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
         raise NotImplementedError()
 
-
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
         raise NotImplementedError()
 
     def register_layer_transfer_counter(self, layer_transfer_counter):
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
-
+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
@@ -281,24 +285,24 @@
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
                 self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()
 
-
-
-
-
-
-
-
-
+    def load_from_host_per_layer(
+        self,
+        host_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
         )
-        return flatten
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data
 
-    def
-
-
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )
 
+    def _get_key_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.k_buffer[layer_id - self.start_layer]
 
-    def
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+        # it is supposed to be used only by attention backend not for information purpose
+        # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]
 
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)
+
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
 
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
+        self.token_stride = kv_lora_rank + qk_rope_head_dim
         self.layer_transfer_counter = None
 
         kv_size = self.get_kv_size_bytes()
@@ -846,21 +875,37 @@
                 self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
             )
 
-    def
-
-
-
-
-
-
-
-
-        self.
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )
 
-    def
-
-
-
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )
 
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
-    def
-
-
-
-
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )
 
-    def
-
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )
 
 
 @triton.jit
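The memory_pool.py hunks above replace the old flat-data `transfer`/`transfer_per_layer` path with two hooks, `load_from_host_per_layer` and `backup_to_host_all_layer`, that hand indexed copies to the `sgl_kernel.kvcacheio` kernels. The snippet below is only an illustrative sketch of the call shape of that interface: `ToyPool`, the plain tensor indexing, and the `io_backend="direct"` value are stand-ins, not the real pools or kernels from this release.

```python
import torch

# Illustrative stand-in for the new per-layer host<->device transfer interface;
# the real classes live in sglang/srt/mem_cache and call sgl_kernel.kvcacheio.
class ToyPool:
    def __init__(self, layer_num: int, size: int, dim: int):
        self.layer_num = layer_num
        self.k_buffer = [torch.zeros(size, dim) for _ in range(layer_num)]
        self.v_buffer = [torch.zeros(size, dim) for _ in range(layer_num)]

    def load_from_host_per_layer(
        self, host_pool, host_indices, device_indices, layer_id, io_backend
    ):
        # One layer at a time, so attention on layer i can run while later
        # layers are still being fetched (plain indexing replaces the kernel).
        self.k_buffer[layer_id][device_indices] = host_pool.k_buffer[layer_id][host_indices]
        self.v_buffer[layer_id][device_indices] = host_pool.v_buffer[layer_id][host_indices]

    def backup_to_host_all_layer(self, host_pool, host_indices, device_indices, io_backend):
        # Write-back covers every layer in one call, mirroring the loop above.
        for layer_id in range(self.layer_num):
            host_pool.k_buffer[layer_id][host_indices] = self.k_buffer[layer_id][device_indices]
            host_pool.v_buffer[layer_id][host_indices] = self.v_buffer[layer_id][device_indices]

host, device = ToyPool(2, 16, 4), ToyPool(2, 8, 4)
host.k_buffer[0][torch.tensor([3, 5])] = 1.0
device.load_from_host_per_layer(
    host, torch.tensor([3, 5]), torch.tensor([0, 1]), layer_id=0, io_backend="direct"
)
assert torch.equal(device.k_buffer[0][:2], torch.ones(2, 4))
```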
sglang/srt/mem_cache/memory_pool_host.py
CHANGED
@@ -8,7 +8,6 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import debug_timing
 
 logger = logging.getLogger(__name__)
 
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )
 
-    @
-    def
-
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
 
-
-
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
 
 
 class MLATokenToKVPoolHost(HostKVCache):
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
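With the flat-data and page-copy helpers removed, the host pools now expose K and V as views over a single fused `kv_buffer` through the new `k_buffer`/`v_buffer` properties, which is what lets the device pools address `host_pool.k_buffer[layer_id]` directly in the per-layer transfers. A minimal sketch of that pattern follows; the `[2, layer_num, size, head_num, head_dim]` layout is an assumption inferred from the removed MHA host code, not a quoted shape.

```python
import torch

# Toy host pool: one fused allocation, K and V exposed as zero-copy views.
class ToyHostPool:
    def __init__(self, layer_num=2, size=8, head_num=1, head_dim=4):
        # dim 0 selects K (0) or V (1); assumed layout, see note above
        self.kv_buffer = torch.zeros(2, layer_num, size, head_num, head_dim)

    @property
    def k_buffer(self):
        return self.kv_buffer[0]

    @property
    def v_buffer(self):
        return self.kv_buffer[1]

pool = ToyHostPool()
pool.k_buffer[0, :2] = 1.0  # writes through the view into the fused kv_buffer
assert torch.all(pool.kv_buffer[0, 0, :2] == 1.0)
```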
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
 
         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
 
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
 
         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -210,8 +210,10 @@ class MoEGate(nn.Module):
         self,
         config,
         prefix: str = "",
+        is_nextn: bool = False,
     ):
         super().__init__()
+        self.is_nextn = is_nextn
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
@@ -233,8 +235,10 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )
 
+        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
+            and not self.is_nextn
             and hidden_states.shape[0] < 4
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
@@ -258,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        is_nextn: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -284,7 +289,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
-        self.gate = MoEGate(
+        self.gate = MoEGate(
+            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+        )
 
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
@@ -1776,6 +1783,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
+                is_nextn=is_nextn,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1930,7 +1938,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -2355,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        if self.quant_config and self.quant_config.get_name() == "w4afp8":
+            expert_params_mapping += (
+                get_moe_impl_class().make_expert_input_scale_params_mapping(
+                    num_experts=self.config.n_routed_experts
+                )
+            )
 
         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
         fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (