sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/utils.py

@@ -34,17 +34,3 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
-
-
-@lru_cache(maxsize=1)
-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
-@lru_cache(maxsize=1)
-def is_sm90_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 9) and (
-        torch.version.cuda >= "12.3"
-    )
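
Note: these helpers are removed from layers/utils.py, likely consolidating on the copies in sglang/srt/utils.py (which also changes in this diff, +59 -12 above). A minimal sketch of the gating pattern they implement, assuming a CUDA build of PyTorch:

    from functools import lru_cache

    import torch


    @lru_cache(maxsize=1)
    def is_sm100_supported(device=None) -> bool:
        # SM major version 10 (Blackwell) plus CUDA >= 12.8. The answer cannot
        # change within a process, hence the single-slot cache. Note that the
        # CUDA check is a lexicographic string comparison, not a numeric one.
        return (torch.cuda.get_device_capability(device)[0] == 10) and (
            torch.version.cuda >= "12.8"
        )
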
sglang/srt/lora/lora_manager.py

@@ -420,20 +420,37 @@ class LoRAManager:
     ):
         """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

-        if target_modules is not None:
-            self.target_modules = set(target_modules)
-        else:
-            self.target_modules = set()
-            for config in self.configs.values():
-                if not isinstance(config.target_modules, list):
+        self.target_modules = (
+            get_normalized_target_modules(target_modules) if target_modules else set()
+        )
+
+        for lora_id, config in self.configs.items():
+            if not isinstance(config.target_modules, list):
+                raise ValueError(
+                    f"SGLang currently only supports inferring LoRA target modules when a list of "
+                    "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                    "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                    "enable all support modules types. "
+                )
+
+            adapter_target_modules = get_normalized_target_modules(
+                config.target_modules
+            )
+
+            if target_modules is not None:
+                # When `--lora-target-modules` is provided, validate adapter target modules is a subset of the specified target modules.
+                if not adapter_target_modules.issubset(self.target_modules):
+                    unsupported_modules = adapter_target_modules - self.target_modules
+                    lora_name = self.lora_refs[lora_id].lora_name
                     raise ValueError(
-                        f"SGLang currently only supports inferring LoRA target modules when a list of "
-                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
-                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
-                        "enable all support modules types. "
+                        f"LoRA adapter '{lora_name}' contains target modules {sorted(unsupported_modules)} "
+                        f"that are not included in the specified --lora-target-modules {sorted(self.target_modules)}. "
+                        f"Please update --lora-target-modules to include all required modules: "
+                        f"{sorted(self.target_modules | adapter_target_modules)}, or use 'all' to enable all supported modules."
                     )
-                self.target_modules.update(config.target_modules)
-            self.target_modules = get_normalized_target_modules(self.target_modules)
+            else:
+                # Otherwise, infer target_modules from adapter configs.
+                self.target_modules.update(adapter_target_modules)

         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
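
The new validation reduces to set algebra over normalized module suffixes: normalize both sides, then require the adapter's modules to be a subset of the server's. A standalone sketch, with a hypothetical normalize() standing in for get_normalized_target_modules:

    # Hypothetical stand-in: reduce entries like
    # "model.layers.0.self_attn.q_proj" to their bare suffix.
    def normalize(modules):
        return {m.split(".")[-1] for m in modules}

    def check_adapter(server_modules, adapter_modules):
        unsupported = adapter_modules - server_modules
        if unsupported:
            raise ValueError(
                f"adapter needs {sorted(unsupported)}; extend --lora-target-modules "
                f"to {sorted(server_modules | adapter_modules)} or pass 'all'"
            )

    server = normalize(["q_proj", "k_proj", "v_proj"])
    check_adapter(server, normalize(["model.layers.0.self_attn.q_proj"]))  # passes
    try:
        check_adapter(server, normalize(["o_proj"]))
    except ValueError as e:
        print(e)  # adapter needs ['o_proj']; extend --lora-target-modules ...
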
sglang/srt/managers/cache_controller.py

@@ -22,12 +22,22 @@ from typing import TYPE_CHECKING, List, Optional

 import torch

+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)

@@ -231,6 +241,8 @@ class HiCacheController:
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -240,28 +252,40 @@ class HiCacheController:
         self.io_backend = io_backend

         self.enable_storage = False
-        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
+
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
-            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+            from sglang.srt.mem_cache.hicache_storage import get_hash_str
+
+            self.get_hash_str = get_hash_str
+
+            self.storage_config = self._generate_storage_config(
+                model_name, storage_backend_extra_config
+            )
+            # In MLA backend, only one rank needs to backup the KV cache
+            self.backup_skip = (
+                self.storage_config.is_mla_model
+                # todo: for load balancing, decide which rank to backup the KV cache by hash value
+                and self.storage_config.tp_rank != 0
+                # todo: support other storage backends
+                and self.storage_backend_type in ["file", "mooncake"]
+            )

             if storage_backend == "file":
-                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str
+                from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+                self.storage_backend = HiCacheFile(self.storage_config)
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

                 self.storage_backend = HiCacheNixl()
-                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
                 from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
-                    get_hash_str_mooncake,
                 )

-                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend = MooncakeStore(self.storage_config)
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
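
The new backup_skip flag is computed once at construction rather than on every access (the property removed further down). The rationale, per the diff's own comment: MLA's compressed KV cache is replicated across tensor-parallel ranks rather than sharded, so every rank would write identical pages, and only rank 0 needs to back up; the todo limits the shortcut to the file and mooncake backends for now. A sketch of the condition, assuming four TP ranks:

    from dataclasses import dataclass

    @dataclass
    class Cfg:  # stand-in for HiCacheStorageConfig
        is_mla_model: bool
        tp_rank: int

    def backup_skip(cfg: Cfg, backend: str) -> bool:
        # Mirrors the condition above: non-zero ranks of an MLA model skip
        # the backup, but only for the backends listed in the diff.
        return cfg.is_mla_model and cfg.tp_rank != 0 and backend in ["file", "mooncake"]

    # Only rank 0 performs the backup:
    assert [backup_skip(Cfg(True, r), "file") for r in range(4)] == [False, True, True, True]
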
@@ -279,9 +303,8 @@ class HiCacheController:
                 )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    bytes_per_page, dtype
+                    bytes_per_page, dtype, self.storage_config
                 )
-                self.get_hash_str = get_hash_str
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
@@ -361,6 +384,40 @@ class HiCacheController:
             self.prefetch_thread.start()
             self.backup_thread.start()

+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
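
_generate_storage_config treats storage_backend_extra_config as an opaque JSON string: parse failures are logged and startup proceeds with extra_config=None rather than aborting. A sketch of that contract (the keys are illustrative, not a documented schema):

    import json
    import logging

    logger = logging.getLogger(__name__)

    def parse_extra_config(raw):
        # Mirrors the parsing above: invalid JSON degrades to None.
        if not raw:
            return None
        try:
            return json.loads(raw)
        except Exception as e:
            logger.error(f"Invalid backend extra config JSON: {e}")
            return None

    assert parse_extra_config('{"prefetch_timeout": 3.0}') == {"prefetch_timeout": 3.0}
    assert parse_extra_config("{not json") is None
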
@@ -400,15 +457,6 @@ class HiCacheController:
             self.prefetch_thread.start()
             self.backup_thread.start()

-    @property
-    def backup_skip(self):
-        return (
-            self.is_mla
-            and get_tensor_model_parallel_rank() != 0
-            # todo: only support file and mooncake
-            and self.storage_backend_type in ["file", "mooncake"]
-        )
-
     def write(
         self,
         device_indices: torch.Tensor,
@@ -570,57 +618,92 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value

-    def zerocopy_page_transfer(self, operation, batch_size=8):
+    # zero copy
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
+            hash_values, host_indices
         )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_dsts = dsts[i : i + batch_size]
-            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
-            if page_data is None:
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            operation.increment(self.page_size * len(hashes))
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )
+
+    # zero copy
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)
+
+    # non-zero copy
+    def _generic_page_get(self, operation, hash_values, host_indices):
+        # todo: zero copy
+        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+            hash_values
+        )
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
                 logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    completed_tokens += self.page_size
+            if operation.increment(self.page_size):
+                self.mem_pool_host.set_from_flat_data_page(
+                    host_indices[i * self.page_size],
+                    page_data[i],
+                )
             else:
                 break

-    def generic_page_transfer(self, operation, batch_size=8):
+    def _page_transfer(self, operation):
+        # Select the get function and batch size
+        if self.is_mooncake_backend():
+            get_func = self._mooncake_page_get
+            batch_size = 128
+        elif self.storage_backend_type == "hf3fs":
+            if self.mem_pool_host.layout == "page_first":
+                get_func = self._3fs_zero_copy_page_get
+            elif self.mem_pool_host.layout == "layer_first":
+                get_func = self._generic_page_get
+            batch_size = 128
+        else:
+            get_func = self._generic_page_get
+            batch_size = 8
+
+        # Transfer batch by batch
         for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            # todo: zero copy
-            dummy_page_dst = [
-                self.mem_pool_host.get_dummy_flat_data_page()
-                for _ in range(len(page_hashes))
+            batch_hashes = operation.hash_value[i : i + batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
-            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
-            if page_data is None:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
-                )
-                break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    self.mem_pool_host.set_from_flat_data_page(
-                        operation.host_indices[completed_tokens],
-                        page_data[i],
-                    )
-                    completed_tokens += self.page_size
-            else:
-                break
-
-    def mooncake_page_transfer(self, operation):
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            operation.hash_value, operation.host_indices
-        )
-        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
-        operation.increment(len(operation.hash_value) * self.page_size)
+            prev_completed_tokens = operation.completed_tokens
+            # Get one batch token, and update the completed_tokens if succeed
+            get_func(operation, batch_hashes, batch_host_indices)
+            # Check termination
+            if (
+                operation.completed_tokens
+                != prev_completed_tokens + len(batch_hashes) * self.page_size
+            ):
+                break  # Some operations fail or operation terminated by controller
+        # release pre-allocated memory
+        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])

     def is_mooncake_backend(self):
         return self.storage_backend_type == "mooncake"
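
This hunk collapses three backend-specific transfer loops into one driver: _page_transfer picks a getter and a batch size, then infers failure by comparing completed_tokens against what a fully successful batch would have produced, finally freeing the unused tail of pre-allocated host memory. A runnable sketch of that driver pattern, with hypothetical stand-ins for the operation object and a backend getter:

    class Op:
        def __init__(self, hashes, page_size):
            self.hash_value = hashes
            self.page_size = page_size
            self.completed_tokens = 0

    def flaky_get(op, batch):
        # Pretend every page before "bad" is fetched successfully.
        for h in batch:
            if h == "bad":
                return
            op.completed_tokens += op.page_size

    def transfer(op, get_func, batch_size):
        for i in range(0, len(op.hash_value), batch_size):
            batch = op.hash_value[i : i + batch_size]
            prev = op.completed_tokens
            get_func(op, batch)
            if op.completed_tokens != prev + len(batch) * op.page_size:
                break  # partial batch: stop; the caller frees the unfilled tail

    op = Op(["a", "b", "bad", "c"], page_size=64)
    transfer(op, flaky_get, batch_size=2)
    assert op.completed_tokens == 128  # "a" and "b" landed; "bad" stopped the loop
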
@@ -632,15 +715,7 @@ class HiCacheController:
         while not self.stop_event.is_set():
            try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if self.is_mooncake_backend():
-                    self.mooncake_page_transfer(operation)
-                elif self.storage_backend_type == "hf3fs":
-                    if self.mem_pool_host.layout == "page_first":
-                        self.zerocopy_page_transfer(operation, batch_size=128)
-                    elif self.mem_pool_host.layout == "layer_first":
-                        self.generic_page_transfer(operation, batch_size=128)
-                    else:
-                        self.generic_page_transfer(operation)
+                self._page_transfer(operation)

                 if self.tp_world_size > 1:
                     # to ensure all TP workers release the host memory at the same time
@@ -662,6 +737,27 @@ class HiCacheController:
         # todo: more sophisticated rate limiting based on storage backend performance
         return True

+    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+        last_hash = operation.last_hash
+        tokens_to_fetch = operation.token_ids
+
+        storage_query_count = 0
+        remaining_tokens = len(tokens_to_fetch)
+        hash_value = []
+        while remaining_tokens >= self.page_size:
+            last_hash = self.get_hash_str(
+                tokens_to_fetch[
+                    storage_query_count : storage_query_count + self.page_size
+                ],
+                last_hash,
+            )
+            hash_value.append(last_hash)
+            storage_query_count += self.page_size
+            remaining_tokens -= self.page_size
+        # deferring to batch exists
+        hit_page_num = self.storage_backend.batch_exists(hash_value)
+        return hash_value[:hit_page_num], hit_page_num * self.page_size
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
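
The new helper replaces two per-backend probing strategies (a per-page exists() for most backends, a deferred batch query for mooncake) with one pass: hash every full page, chaining each page's hash on its predecessor so a key commits to its entire prefix, then issue a single batch_exists and keep only the leading hits. A sketch with a hypothetical chained hash:

    import hashlib

    PAGE_SIZE = 4

    def get_hash_str(tokens, prior):
        # Hypothetical chained hash: folding the previous page's hash into
        # the next one makes each key encode its whole prefix.
        h = hashlib.sha256()
        h.update((prior or "").encode())
        h.update(",".join(map(str, tokens)).encode())
        return h.hexdigest()

    tokens = list(range(10))  # two full pages, two leftover tokens
    last, keys = None, []
    remaining, start = len(tokens), 0
    while remaining >= PAGE_SIZE:
        last = get_hash_str(tokens[start : start + PAGE_SIZE], last)
        keys.append(last)
        start += PAGE_SIZE
        remaining -= PAGE_SIZE

    # batch_exists(keys) reports how many *leading* pages hit; the helper
    # then returns keys[:hit_page_num] and hit_page_num * PAGE_SIZE tokens.
    assert len(keys) == 2
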
@@ -675,38 +771,12 @@ class HiCacheController:
                 if operation is None:
                     continue

-                storage_hit_count = 0
                 if (
                     operation.host_indices is not None
                 ) and self.prefetch_rate_limit_check():
-                    last_hash = operation.last_hash
-                    tokens_to_fetch = operation.token_ids
-
-                    remaining_tokens = len(tokens_to_fetch)
-                    hash_value = []
-                    while remaining_tokens >= self.page_size:
-                        last_hash = self.get_hash_str(
-                            tokens_to_fetch[
-                                storage_hit_count : storage_hit_count + self.page_size
-                            ],
-                            last_hash,
-                        )
-
-                        # todo, more unified interface
-                        if not self.is_mooncake_backend():
-                            if not self.storage_backend.exists(last_hash):
-                                break
-                        hash_value.append(last_hash)
-                        storage_hit_count += self.page_size
-                        remaining_tokens -= self.page_size
-
-                    if self.is_mooncake_backend():
-                        # deferring to batch exists for mooncake store
-                        exist_result = self.storage_backend.exists(hash_value)
-                        storage_hit_count = (
-                            sum(1 for v in exist_result.values() if v != 0)
-                            * self.page_size
-                        )
+                    hash_value, storage_hit_count = self._generic_storage_hit_query(
+                        operation
+                    )

                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
@@ -755,59 +825,64 @@ class HiCacheController:
             self.backup_queue.put(operation)
         return operation.id

-    def zerocopy_page_backup(self, operation, batch_size=8):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
+    # non-zero copy
+    def _generic_page_set(self, hash_values, host_indices) -> bool:
+        data = [
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            for i in range(len(hash_values))
+        ]
+        return self.storage_backend.batch_set(hash_values, data)
+
+    # zero copy
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
         )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_data = dsts[i : i + batch_size]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
-            if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success

-    def generic_page_backup(self, operation, batch_size=8):
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        return self.storage_backend.batch_set(hashes, dsts)
+
+    # Backup batch by batch
+    def _page_backup(self, operation):
+        # Select the set function and batch size
+        if self.is_mooncake_backend():
+            backup_set_func = self._mooncake_page_set
+            batch_size = 128
+        elif self.storage_backend_type == "hf3fs":
+            if self.mem_pool_host.layout == "page_first":
+                backup_set_func = self._3fs_zero_copy_page_set
+            elif self.mem_pool_host.layout == "layer_first":
+                backup_set_func = self._generic_page_set
+            batch_size = 128
+        else:
+            backup_set_func = self._generic_page_set
+            batch_size = 8
+        # Backup batch by batch
         for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            page_data = [
-                self.mem_pool_host.get_flat_data_page(
-                    operation.host_indices[j * self.page_size]
-                )
-                for j in range(i, i + len(page_hashes))
+            batch_hashes = operation.hash_value[i : i + batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
+            # Set one batch token, and record if success.
+            # todo: allow partial success
+            success = backup_set_func(batch_hashes, batch_host_indices)
             if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
-
-    def mooncake_page_backup(self, operation):
-        if len(operation.hash_value):
-            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
-            indices = operation.host_indices.tolist()
-            non_exist_keys = []
-            non_exist_indices = []
-            for i in range(len(operation.hash_value)):
-                if not exist_hashvalues[operation.hash_value[i]]:
-                    non_exist_keys.append(operation.hash_value[i])
-                    non_exist_indices.extend(
-                        indices[i * self.page_size : (i + 1) * self.page_size]
-                    )
-            if len(non_exist_keys) > 0:
-                key_strs, buffer_ptrs, buffer_sizes = (
-                    self.mem_pool_host.get_buffer_meta(
-                        non_exist_keys, non_exist_indices
-                    )
-                )
-                # TODO: check the return value of batch set to see how many tokens are set successfully
-                self.storage_backend.batch_set(
-                    key_strs,
-                    target_location=buffer_ptrs,
-                    target_sizes=buffer_sizes,
+                logger.warning(
+                    f"Write page to storage: {len(batch_hashes)} pages failed."
                 )
-            operation.completed_tokens += len(operation.hash_value) * self.page_size
+                break
+            operation.completed_tokens += self.page_size * len(batch_hashes)

     def backup_thread_func(self):
         """
@@ -820,15 +895,7 @@ class HiCacheController:
                     continue

                 if not self.backup_skip:
-                    if self.is_mooncake_backend():
-                        self.mooncake_page_backup(operation)
-                    elif self.storage_backend_type == "hf3fs":
-                        if self.mem_pool_host.layout == "page_first":
-                            self.zerocopy_page_backup(operation, batch_size=128)
-                        elif self.mem_pool_host.layout == "layer_first":
-                            self.generic_page_backup(operation, batch_size=128)
-                        else:
-                            self.generic_page_backup(operation)
+                    self._page_backup(operation)
                     min_completed_tokens = operation.completed_tokens
                 else:
                     min_completed_tokens = len(operation.token_ids)
sglang/srt/managers/detokenizer_manager.py

@@ -106,6 +106,8 @@ class DetokenizerManager:
             ]
         )

+        self.is_tool_call_parser_gpt_oss = server_args.tool_call_parser == "gpt-oss"
+
     def event_loop(self):
         """The event loop that handles requests"""
         while True:
@@ -133,6 +135,9 @@ class DetokenizerManager:

         # Trim stop token.
         if isinstance(matched, int) and isinstance(output, list):
+            # 200012 <|call|> is the tool call token and one of eos tokens for gpt-oss model
+            if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
+                return output
             assert len(output) > 0
             return output[:-1]
         return output
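
Trimming normally strips the matched stop token id from the tail of the output; for gpt-oss, token 200012 (<|call|>) is both an EOS token and the marker the tool-call parser needs, so when that parser is active the token is left in place. A sketch (the helper name is illustrative):

    TOOL_CALL_TOKEN = 200012  # <|call|> in the gpt-oss vocabulary

    def trim_stop_token(output, matched, is_gpt_oss_tool_parser):
        # Illustrative reduction of the logic above.
        if isinstance(matched, int) and isinstance(output, list):
            if output and output[-1] == TOOL_CALL_TOKEN and is_gpt_oss_tool_parser:
                return output  # keep <|call|> for the tool-call parser
            return output[:-1]
        return output

    assert trim_stop_token([1, 2, 200012], 200012, True) == [1, 2, 200012]
    assert trim_stop_token([1, 2, 9], 9, True) == [1, 2]
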
sglang/srt/managers/io_struct.py

@@ -533,6 +533,21 @@ class TokenizedGenerateReqInput:
     dp_balance_id: int = -1


+@dataclass
+class BatchTokenizedGenerateReqInput:
+    # The batch of tokenized requests
+    batch: List[TokenizedGenerateReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -668,6 +683,21 @@ class TokenizedEmbeddingReqInput:
     dp_balance_id: int = -1


+@dataclass
+class BatchTokenizedEmbeddingReqInput:
+    # The batch of tokenized embedding requests
+    batch: List[TokenizedEmbeddingReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
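
Both wrappers give a list of tokenized requests plain sequence semantics (len, indexing, iteration) so downstream code can handle a single request and a batch uniformly. A usage sketch with a stub element type:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Req:  # stub for TokenizedGenerateReqInput
        rid: str

    @dataclass
    class BatchReq:  # mirrors BatchTokenizedGenerateReqInput
        batch: List[Req]

        def __len__(self):
            return len(self.batch)

        def __getitem__(self, i):
            return self.batch[i]

        def __iter__(self):
            return iter(self.batch)

    reqs = BatchReq([Req("a"), Req("b")])
    assert len(reqs) == 2 and reqs[0].rid == "a"
    assert [r.rid for r in reqs] == ["a", "b"]
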