sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/ep_moe/layer.py +19 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
- sglang/srt/layers/quantization/fp8.py +52 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +35 -35
- sglang/srt/managers/scheduler.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +15 -6
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +8 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/step3_vl.py +0 -3
- sglang/srt/server_args.py +40 -6
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py

@@ -8,6 +8,21 @@ import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )
 
 logger = logging.getLogger(__name__)
 
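Note: the `sgl_kernel.kvcacheio` import is gated behind `is_npu()` because the fused KV-transfer kernels ship only with the CUDA build of sgl_kernel; checking once at module load keeps Ascend NPU deployments importable and lets callers pick a fallback path at runtime. A minimal sketch of the same feature-detection idea (the `fast_kernels` module and `fused_gather_copy` kernel are hypothetical stand-ins; the diff itself uses an explicit platform check rather than try/except):

    import torch

    try:
        # Hypothetical accelerator-only kernel, absent on unsupported platforms.
        from fast_kernels import fused_gather_copy
        HAS_FUSED = True
    except ImportError:
        HAS_FUSED = False

    def gather_copy(dst: torch.Tensor, src: torch.Tensor, idx: torch.Tensor) -> None:
        # Copy the selected token slots, using the fused kernel when available
        # and a plain PyTorch gather/scatter otherwise.
        if HAS_FUSED:
            fused_gather_copy(dst, src, idx)
        else:
            dst[idx] = src[idx]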
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-        self.page_size = page_size
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
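These two abstract methods fix the host-to-device transfer contract that the subclasses below implement per layout and IO backend: loads happen one layer at a time (so layer i can be computed while layer i+1 is still in flight), while backups sweep all layers in one call. A rough reference for the intended semantics in plain PyTorch indexing, assuming per-layer buffers whose leading axis is the token slot (the real implementations use the fused `sgl_kernel` transfers or `transfer_kv_direct` instead):

    import torch

    def load_to_device_per_layer_ref(host_layer, device_layer, host_idx, device_idx):
        # H2D: gather the requested host tokens, scatter them into the device pool.
        device_layer[device_idx] = host_layer[host_idx].to(device_layer.device)

    def backup_from_device_all_layer_ref(host_layers, device_layers, host_idx, device_idx):
        # D2H: copy every layer's selected tokens back into (pinned) host memory.
        for host_layer, device_layer in zip(host_layers, device_layers):
            host_layer[host_idx] = device_layer[device_idx].to("cpu")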
@@ -105,6 +141,14 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
@@ -230,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
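The constructor now precomputes `k_data_refs`/`v_data_refs` plus uint64 pointer tables so the all-layer kernels can be launched once over every layer rather than once per layer. A toy sketch of the pointer-table idea (shapes invented; this is not the actual kernel interface):

    import torch

    # One base address per layer, packed into a single uint64 tensor that a
    # kernel can index; the Python list keeps the buffers alive so the raw
    # pointers stay valid for the kernel to dereference.
    layers = [torch.zeros(16, 8) for _ in range(4)]
    ptrs = torch.tensor([t.data_ptr() for t in layers], dtype=torch.uint64)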
@@ -245,25 +308,156 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
-    # todo, page first memory layout
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
+
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
         ptr_list = []
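The two layouts trade off transfer granularity: `layer_first` keeps each layer's tokens contiguous, which suits per-layer loads, while `page_first` keeps all layers of one page contiguous, so a whole page can be staged to or from external storage as a single flat run. `get_flat_data_page` and `set_from_flat_data_page` accordingly differ only in which axis the `index : index + page_size` slice applies to. A tiny self-contained illustration of the contiguity difference (toy dimensions, not sglang code):

    import torch

    layers, tokens, heads, dim, page = 2, 8, 1, 4, 4
    lf = torch.arange(2 * layers * tokens * heads * dim, dtype=torch.float32).reshape(
        2, layers, tokens, heads, dim
    )  # layer_first: (kv, layer, token, head, dim)
    pf = lf.permute(0, 2, 1, 3, 4).contiguous()  # page_first: (kv, token, layer, head, dim)

    assert not lf[0, :, 0:page].is_contiguous()  # page is strided across layers
    assert pf[0, 0:page].is_contiguous()  # page is one contiguous run per k/v plane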
@@ -302,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    @property
-    def k_buffer(self):
-        return self.kv_buffer[0]
-
-    @property
-    def v_buffer(self):
-        return self.kv_buffer[1]
-
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -320,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )
 
     def get_size_per_token(self):
@@ -340,28 +539,146 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
 
     def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )
 
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
 
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
         ptr_list = []
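Unlike the MHA pool, the MLA pool holds a single fused latent vector of width `kv_lora_rank + qk_rope_head_dim` per token per layer (head dimension 1, no separate K and V), which is why its transfer helpers take a single `src`/`dst` rather than k/v pairs. A back-of-envelope host sizing sketch consistent with the `HostKVCache` constructor shown above (the model dimensions are illustrative DeepSeek-style values, not taken from this diff):

    kv_lora_rank, qk_rope_head_dim = 512, 64
    layer_num, itemsize = 61, 2  # e.g. 61 layers stored in bf16
    size_per_token = (kv_lora_rank + qk_rope_head_dim) * itemsize * layer_num
    host_size_gb = 100
    num_host_tokens = int(host_size_gb * 1e9 // size_per_token)
    print(size_per_token, num_host_tokens)  # 70272 bytes/token -> ~1.42M tokens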
sglang/srt/mem_cache/nixl/hicache_nixl.py (new file)

@@ -0,0 +1,163 @@
+import hashlib
+import logging
+import os
+import time
+import uuid
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+
+from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
+
+try:
+    from nixl._api import nixl_agent, nixl_agent_config
+except ImportError as e:
+    raise ImportError(
+        "Please install NIXL by following the instructions at "
+        "https://github.com/ai-dynamo/nixl/blob/main/README.md "
+        "to use HiCacheNixl storage backend."
+    ) from e
+
+logger = logging.getLogger(__name__)
+
+
+class HiCacheNixl(HiCacheStorage):
+    """HiCacheNixl provides high-performance storage using NIXL plugins."""
+
+    def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
+        """Initialize NIXL storage connector."""
+        self.file_manager = (
+            NixlFileManager(file_path)
+            if plugin not in NixlBackendSelection.OBJ_PLUGINS
+            else None
+        )
+
+        agent_config = nixl_agent_config(backends=[])
+        self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
+        self.agent = nixl_agent(self.agent_name, agent_config)
+
+        self.backend_selector = NixlBackendSelection(plugin)
+        if not self.backend_selector.create_backend(self.agent):
+            raise RuntimeError("Failed to create NIXL backend")
+
+        self.registration = NixlRegistration(self.agent)
+
+    def _execute_transfer(
+        self, tensors: List[torch.Tensor], keys: List[str], direction: str
+    ) -> bool:
+        if len(tensors) != len(keys):
+            logger.error("Mismatch between number of tensors and files/objects")
+            return False
+
+        if not self.registration.register_buffers(tensors):
+            logger.error("Failed to register tensors")
+            return False
+
+        # Get transfer tuples based on backend type
+        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        if self.backend_selector.mem_type == "FILE":
+            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not file_tuples or not self.registration.register_files(file_tuples):
+                logger.error("Failed to prepare files for transfer")
+                return False
+            transfer_tuples = [
+                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
+            ]
+        else:
+            if not self.registration.register_objects(keys, tensors):
+                logger.error("Failed to register objects")
+                return False
+            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
+
+        try:
+            # Get transfer descriptors
+            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
+                file_descs := self.agent.get_xfer_descs(
+                    transfer_tuples, self.backend_selector.mem_type
+                )
+            ) is None:
+                logger.error("Failed to get transfer descriptors")
+                return False
+
+            # Initialize and execute transfer
+            if (
+                xfer_req := self.agent.initialize_xfer(
+                    direction, tensor_descs, file_descs, self.agent_name
+                )
+            ) is None:
+                logger.error("Failed to create transfer request")
+                return False
+
+            state = self.agent.transfer(xfer_req)
+            while state != "DONE":
+                state = self.agent.check_xfer_state(xfer_req)
+                if state == "ERR":
+                    logger.error("Transfer failed")
+                    return False
+                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to execute transfer: {e}")
+            import traceback
+
+            logger.error(f"Traceback: {traceback.format_exc()}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        if not keys:
+            return True
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = []
+            for key in keys:
+                tensor_path = self.file_manager.get_file_path(key)
+                if not self.file_manager.create_file(tensor_path):
+                    logger.error(f"Failed to create file {tensor_path}")
+                    return False
+                file_paths.append(tensor_path)
+            return self._execute_transfer(values, file_paths, "WRITE")
+        else:
+            return self._execute_transfer(values, keys, "WRITE")
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        return self.batch_set([key], [value])
+
+    def get(
+        self, key: str, dst_tensor: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        if dst_tensor is None:  # To be removed, being compatible with the current API
+            return None
+        result = self.batch_get([key], [dst_tensor])
+        return result[0] if result else None
+
+    def batch_get(
+        self, keys: List[str], dst_tensors: List[torch.Tensor]
+    ) -> List[Optional[torch.Tensor]]:
+        if not keys:
+            return []
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = [self.file_manager.get_file_path(key) for key in keys]
+            success = self._execute_transfer(dst_tensors, file_paths, "READ")
+        else:
+            success = self._execute_transfer(dst_tensors, keys, "READ")
+        return dst_tensors if success else [None] * len(keys)
+
+    def exists(self, key: str) -> bool:
+        tuples = self.registration.create_query_tuples(
+            key,
+            self.backend_selector.mem_type,
+            self.file_manager if self.backend_selector.mem_type == "FILE" else None,
+        )
+        if not tuples:
+            return False
+
+        query_res = self.agent.query_memory(
+            tuples,
+            self.backend_selector.backend_name,
+            mem_type=self.backend_selector.mem_type,
+        )
+        return query_res[0] is not None  # can be expanded to multiple keys