sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/launch_lb.py +5 -20
  4. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  5. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  6. sglang/srt/layers/attention/utils.py +6 -1
  7. sglang/srt/layers/moe/ep_moe/layer.py +19 -34
  8. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
  9. sglang/srt/layers/quantization/fp8.py +52 -0
  10. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  11. sglang/srt/managers/cache_controller.py +35 -35
  12. sglang/srt/managers/scheduler.py +1 -0
  13. sglang/srt/mem_cache/hicache_storage.py +15 -6
  14. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  15. sglang/srt/mem_cache/memory_pool.py +15 -118
  16. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  17. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  18. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  19. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  20. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  21. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  22. sglang/srt/model_executor/model_runner.py +8 -1
  23. sglang/srt/model_loader/weight_utils.py +2 -0
  24. sglang/srt/models/deepseek_v2.py +5 -6
  25. sglang/srt/models/glm4_moe.py +3 -3
  26. sglang/srt/models/step3_vl.py +0 -3
  27. sglang/srt/server_args.py +40 -6
  28. sglang/srt/utils.py +1 -0
  29. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  30. sglang/version.py +1 -1
  31. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
  32. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
  33. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  34. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  35. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py
@@ -8,6 +8,21 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )

 logger = logging.getLogger(__name__)

@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
         device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
+        page_size: int,
+        layout: str,
         pin_memory: bool,
         device: str,
-        page_size: int,
     ):
         self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
         self.pin_memory = pin_memory
         self.device = device
-        self.page_size = page_size
+
+        self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
         if host_size > 0:
             self.size = int(host_size * 1e9 // self.size_per_token)
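
Note: host_size is specified in GB, so the host pool capacity in tokens is host_size * 1e9 // size_per_token. A quick worked example with illustrative MHA dimensions (hypothetical numbers, not from this diff):

    # e.g. 32 layers, 8 KV heads, head_dim 128, bf16 (2 bytes)
    layer_num, head_num, head_dim, itemsize = 32, 8, 128, 2
    size_per_token = head_dim * head_num * layer_num * itemsize * 2  # K and V: 131072 bytes
    host_size_gb = 64
    num_tokens = int(host_size_gb * 1e9 // size_per_token)  # ~488281 token slots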
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def get_flat_data_page(self, index) -> torch.Tensor:
         """
@@ -105,6 +141,14 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()

+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
@@ -230,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )

     def get_size_per_token(self):
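
Note: besides the fused "kernel" backend imported above from sgl_kernel.kvcacheio, the hunks below support a "direct" backend. A rough pure-torch sketch of what a direct per-layer copy amounts to (a hypothetical helper, not the actual transfer_kv_direct implementation):

    import torch

    def copy_kv_layer_direct(src, dst, src_indices, dst_indices):
        # src/dst: one layer's KV buffer, shape [num_tokens, head_num, head_dim].
        # Gather the selected source slots, then scatter into the destination pool;
        # the fused kernels avoid this staging copy and batch pages together.
        staged = src[src_indices].to(dst.device, non_blocking=True)
        dst[dst_indices] = staged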
@@ -245,25 +308,156 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

     def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
         return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )

-    # todo, page first memory layout
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
+
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), "Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), "Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()

     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
         ptr_list = []
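
Note: the two layouts differ only in which axis is outermost. With token_stride_size = head_num * head_dim * itemsize and layout_dim = token_stride_size * layer_num (as set in init_kv_buffer above), the byte offset of element (layer_id, token_id) inside one K or V plane works out as follows. This is why the page_first kernels take a layout_dim argument, and why the direct backend, which copies a layer's contiguous rows, only supports layer_first. A sketch mirroring the diff's own formulas:

    def kv_byte_offset(layout, layer_id, token_id, token_stride_size, layer_num, size):
        if layout == "layer_first":   # [layer, token, ...]: a layer's tokens are contiguous
            return (layer_id * size + token_id) * token_stride_size
        elif layout == "page_first":  # [token, layer, ...]: a token's layers are contiguous
            layout_dim = token_stride_size * layer_num
            return token_id * layout_dim + layer_id * token_stride_size
        raise ValueError(f"Unsupported layout: {layout}")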
@@ -302,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    @property
-    def k_buffer(self):
-        return self.kv_buffer[0]
-
-    @property
-    def v_buffer(self):
-        return self.kv_buffer[1]
-

 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -320,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
         host_to_device_ratio: float,
         host_size: int,
         page_size: int,
+        layout: str,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
+        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
         )

     def get_size_per_token(self):
@@ -340,28 +539,146 @@ class MLATokenToKVPoolHost(HostKVCache):
         )

     def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                 self.layer_num,
                 self.size,
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.pin_memory,
         )

+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), "Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), "Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
     def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()

     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
         ptr_list = []
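
Note: MLA stores a single compressed KV vector per token per layer, so the strides are much smaller than MHA. A worked example with DeepSeek-style dimensions (illustrative numbers, not from this diff):

    kv_lora_rank, qk_rope_head_dim, itemsize = 512, 64, 2             # bf16
    token_stride_size = (kv_lora_rank + qk_rope_head_dim) * itemsize  # 1152 bytes
    layer_num = 61
    layout_dim = token_stride_size * layer_num  # 70272 bytes: per-token stride in page_first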
sglang/srt/mem_cache/nixl/hicache_nixl.py (new file)
@@ -0,0 +1,163 @@
+import hashlib
+import logging
+import os
+import time
+import uuid
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+
+from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
+
+try:
+    from nixl._api import nixl_agent, nixl_agent_config
+except ImportError as e:
+    raise ImportError(
+        "Please install NIXL by following the instructions at "
+        "https://github.com/ai-dynamo/nixl/blob/main/README.md "
+        "to use HiCacheNixl storage backend."
+    ) from e
+
+logger = logging.getLogger(__name__)
+
+
+class HiCacheNixl(HiCacheStorage):
+    """HiCacheNixl provides high-performance storage using NIXL plugins."""
+
+    def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
+        """Initialize NIXL storage connector."""
+        self.file_manager = (
+            NixlFileManager(file_path)
+            if plugin not in NixlBackendSelection.OBJ_PLUGINS
+            else None
+        )
+
+        agent_config = nixl_agent_config(backends=[])
+        self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
+        self.agent = nixl_agent(self.agent_name, agent_config)
+
+        self.backend_selector = NixlBackendSelection(plugin)
+        if not self.backend_selector.create_backend(self.agent):
+            raise RuntimeError("Failed to create NIXL backend")
+
+        self.registration = NixlRegistration(self.agent)
+
+    def _execute_transfer(
+        self, tensors: List[torch.Tensor], keys: List[str], direction: str
+    ) -> bool:
+        if len(tensors) != len(keys):
+            logger.error("Mismatch between number of tensors and files/objects")
+            return False
+
+        if not self.registration.register_buffers(tensors):
+            logger.error("Failed to register tensors")
+            return False
+
+        # Get transfer tuples based on backend type
+        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        if self.backend_selector.mem_type == "FILE":
+            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not file_tuples or not self.registration.register_files(file_tuples):
+                logger.error("Failed to prepare files for transfer")
+                return False
+            transfer_tuples = [
+                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
+            ]
+        else:
+            if not self.registration.register_objects(keys, tensors):
+                logger.error("Failed to register objects")
+                return False
+            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
+
+        try:
+            # Get transfer descriptors
+            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
+                file_descs := self.agent.get_xfer_descs(
+                    transfer_tuples, self.backend_selector.mem_type
+                )
+            ) is None:
+                logger.error("Failed to get transfer descriptors")
+                return False
+
+            # Initialize and execute transfer
+            if (
+                xfer_req := self.agent.initialize_xfer(
+                    direction, tensor_descs, file_descs, self.agent_name
+                )
+            ) is None:
+                logger.error("Failed to create transfer request")
+                return False
+
+            state = self.agent.transfer(xfer_req)
+            while state != "DONE":
+                state = self.agent.check_xfer_state(xfer_req)
+                if state == "ERR":
+                    logger.error("Transfer failed")
+                    return False
+                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to execute transfer: {e}")
+            import traceback
+
+            logger.error(f"Traceback: {traceback.format_exc()}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        if not keys:
+            return True
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = []
+            for key in keys:
+                tensor_path = self.file_manager.get_file_path(key)
+                if not self.file_manager.create_file(tensor_path):
+                    logger.error(f"Failed to create file {tensor_path}")
+                    return False
+                file_paths.append(tensor_path)
+            return self._execute_transfer(values, file_paths, "WRITE")
+        else:
+            return self._execute_transfer(values, keys, "WRITE")
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        return self.batch_set([key], [value])
+
+    def get(
+        self, key: str, dst_tensor: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        if dst_tensor is None:  # To be removed, being compatible with the current API
+            return None
+        result = self.batch_get([key], [dst_tensor])
+        return result[0] if result else None
+
+    def batch_get(
+        self, keys: List[str], dst_tensors: List[torch.Tensor]
+    ) -> List[Optional[torch.Tensor]]:
+        if not keys:
+            return []
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = [self.file_manager.get_file_path(key) for key in keys]
+            success = self._execute_transfer(dst_tensors, file_paths, "READ")
+        else:
+            success = self._execute_transfer(dst_tensors, keys, "READ")
+        return dst_tensors if success else [None] * len(keys)
+
+    def exists(self, key: str) -> bool:
+        tuples = self.registration.create_query_tuples(
+            key,
+            self.backend_selector.mem_type,
+            self.file_manager if self.backend_selector.mem_type == "FILE" else None,
+        )
+        if not tuples:
+            return False
+
+        query_res = self.agent.query_memory(
+            tuples,
+            self.backend_selector.backend_name,
+            mem_type=self.backend_selector.mem_type,
+        )
+        return query_res[0] is not None  # can be expanded to multiple keys
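
Note: a minimal usage sketch of the new backend, assuming NIXL is installed and a file plugin is auto-selected (paths and shapes are illustrative):

    import torch
    from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl

    storage = HiCacheNixl(file_path="/tmp/hicache_storage", plugin="auto")

    page = torch.randn(4096, dtype=torch.float16)    # a flattened KV page
    assert storage.set("page-0", page)               # WRITE through NIXL

    out = torch.empty_like(page)                     # get() requires a destination buffer
    restored = storage.get("page-0", dst_tensor=out)
    assert restored is not None and torch.equal(restored, page)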