sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -67,6 +67,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -80,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -253,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -510,7 +515,10 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
+                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -533,6 +541,7 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )

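Note: the (request type, handler) pairs registered above are resolved by a simple type-based dispatcher. The sketch below illustrates the dispatch pattern only; the real implementation is TypeBasedDispatcher in sglang.srt.utils and its details may differ.

    # Illustrative sketch of type-based dispatch (not the actual sglang code).
    class TypeBasedDispatcher:
        def __init__(self, mapping):
            # mapping: list of (request_type, handler) pairs, as registered above
            self._mapping = mapping

        def __call__(self, obj):
            # Find the first registered type that matches and run its handler.
            for ty, handler in self._mapping:
                if isinstance(obj, ty):
                    return handler(obj)
            raise ValueError(f"Unhandled request type: {type(obj).__name__}")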
@@ -623,6 +632,8 @@ class Scheduler(
             hicache_mem_layout=server_args.hicache_mem_layout,
             hicache_storage_backend=server_args.hicache_storage_backend,
             hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
         )
         self.tp_worker.register_hicache_layer_transfer_counter(
             self.tree_cache.cache_controller.layer_done_counter
@@ -1018,14 +1029,26 @@ class Scheduler(
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
@@ -1080,6 +1103,17 @@ class Scheduler(
                 )
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
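Note: the unwrap/re-wrap logic above implies that MultiTokenizerWarpper (spelled this way in this release) is a thin envelope holding the sending worker's id plus a payload. The dataclass below is an assumption about its shape inferred from this diff; the real definition lives in sglang.srt.managers.io_struct.

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class MultiTokenizerWarpper:  # hypothetical reconstruction
        worker_id: int  # pid of the tokenizer worker that sent the request
        obj: Any        # wrapped request on the way in, response on the way out

    # Tokenizer side wraps:  MultiTokenizerWarpper(worker_id=os.getpid(), obj=req)
    # Scheduler side (above) unwraps, dispatches, and re-wraps the reply with
    # the same worker_id so it can be routed back to the originating worker.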
@@ -1253,6 +1287,17 @@ class Scheduler(
         else:
             self._add_request_to_queue(req)

+    def handle_batch_generate_request(
+        self,
+        recv_req: BatchTokenizedGenerateReqInput,
+    ):
+        """Handle optimized batch generate request."""
+        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_generate_request(tokenized_req)
+
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
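Note: handle_batch_generate_request takes len(recv_req) and iterates recv_req directly, so the new batch types must behave like containers of tokenized requests (they are built later in this diff as BatchTokenizedGenerateReqInput(batch=tokenized_objs)). A plausible shape, assumed rather than copied from io_struct.py:

    from dataclasses import dataclass
    from typing import Iterator, List

    @dataclass
    class BatchTokenizedGenerateReqInput:  # hypothetical reconstruction
        batch: List["TokenizedGenerateReqInput"]

        def __len__(self) -> int:
            return len(self.batch)

        def __iter__(self) -> Iterator["TokenizedGenerateReqInput"]:
            return iter(self.batch)

The embedding variant further down iterates the same way.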
@@ -1269,10 +1314,11 @@ class Scheduler(
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
-            last_hash = req.last_host_node.get_last_hash_value()
-            matched_len = len(req.prefix_indices) + req.host_hit_length
-            # todo, free-form fetching, calculating hash keys on the fly
-            if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+            if req.last_node.backuped:
+                # only to initiate the prefetch if the last node is backuped
+                # otherwise, the allocated GPU memory must be locked for integrity
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
                 new_input_tokens = req.fill_ids[matched_len:]
                 self.tree_cache.prefetch_from_storage(
                     req.rid, req.last_host_node, new_input_tokens, last_hash
@@ -1335,6 +1381,19 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)

+    def handle_batch_embedding_request(
+        self,
+        recv_req: BatchTokenizedEmbeddingReqInput,
+    ):
+        """Handle optimized batch embedding request."""
+        logger.debug(
+            f"Processing batch embedding request with {len(recv_req)} requests"
+        )
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_embedding_request(tokenized_req)
+
     def self_check_during_idle(self):
         self.check_memory()
         self.check_tree_cache()
@@ -1460,7 +1519,7 @@ class Scheduler(
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
@@ -2164,6 +2223,16 @@ class Scheduler(
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)

+    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+        if self.enable_hierarchical_cache:
+            self.tree_cache.clear_storage_backend()
+            logger.info("Hierarchical cache cleared successfully!")
+            if_success = True
+        else:
+            logging.warning("Hierarchical cache is not enabled.")
+            if_success = False
+        return ClearHiCacheReqOutput(success=if_success)
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if (
@@ -2335,6 +2404,10 @@ class Scheduler(
             # We still need to send something back to TokenizerManager to clean up the state.
             req = self.waiting_queue.pop(i)
             self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.tree_cache.cache_finished_req(req)
+
             logger.debug(f"Abort queued request. {req.rid=}")

         # Delete the requests in the grammar queue
@@ -2414,6 +2487,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2513,7 +2590,15 @@ def is_health_check_generate_req(recv_req):


 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )


 def run_scheduler_process(
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)

-                    if req.return_logprob:
+                    if batch.return_logprob:
                         assert extend_logprob_start_len_per_req is not None
                         assert extend_input_len_per_req is not None
                         extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                         extend_input_len = extend_input_len_per_req[i]
                         num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
+                        if req.return_logprob:
+                            self.add_logprob_return_values(
+                                i,
+                                req,
+                                logprob_pt,
+                                next_token_ids,
+                                num_input_logprobs,
+                                logits_output,
+                            )
                         logprob_pt += num_input_logprobs

             if (
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
                         skip_stream_req = req

                     # Incrementally update input logprobs.
-                    if req.return_logprob:
+                    if batch.return_logprob:
                         extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                         extend_input_len = extend_input_len_per_req[i]
                         if extend_logprob_start_len < extend_input_len:
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
                             num_input_logprobs = (
                                 extend_input_len - extend_logprob_start_len
                             )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
+                            if req.return_logprob:
+                                self.add_input_logprob_return_values(
+                                    i,
+                                    req,
+                                    logits_output,
+                                    logprob_pt,
+                                    num_input_logprobs,
+                                    last_prefill_chunk=False,
+                                )
                             logprob_pt += num_input_logprobs

             self.set_next_batch_sampling_info_done(batch)
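Note on the guard change in both hunks: logits_output packs input logprobs for every request in a batch whose batch-level return_logprob flag is set, so the read offset logprob_pt must advance for each such request even when that particular request did not ask for logprobs itself; only the extraction step is skipped. A toy illustration of the offset bookkeeping (numbers invented):

    extend_input_lens = [5, 3]      # prefilled tokens per request
    start_lens = [0, 0]             # extend_logprob_start_len per request
    wants_logprob = [False, True]   # per-request return_logprob

    logprob_pt = 0
    for i in range(2):
        num_input_logprobs = extend_input_lens[i] - start_lens[i]
        if wants_logprob[i]:
            # Request 1 must read at offset 5, which is only correct because
            # request 0 advanced the pointer even though it wants no logprobs.
            print(i, "reads at offset", logprob_pt)
        logprob_pt += num_input_logprobs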
sglang/srt/managers/scheduler_update_weights_mixin.py

@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
         url = params["url"]

         worker = self.tp_worker.worker
-
         worker.model_runner.save_remote_model(url)

+        if self.draft_worker is not None:
+            draft_url = params.get("draft_url", None)
+            assert (
+                draft_url is not None
+            ), "draft_url must be provided when draft model is enabled"
+            draft_worker = self.draft_worker.worker
+            draft_worker.model_runner.save_remote_model(draft_url)
+
     def save_sharded_model(self, params):
         worker = self.tp_worker.worker

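Note: with the block above, the params dict handed to save_remote_model must also carry a draft_url whenever a speculative draft worker is loaded, or the assert fires. A hypothetical example (URL values are placeholders):

    params = {
        "url": "s3://bucket/target-model",       # destination for the main model
        "draft_url": "s3://bucket/draft-model",  # required iff draft_worker is set
    }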
sglang/srt/managers/tokenizer_manager.py

@@ -71,6 +71,10 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -90,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -262,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
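Note: the bind flag flips together with the endpoint. With tokenizer_worker_num > 1, every tokenizer worker connects (bind=False) a PUSH socket to one shared IPC endpoint instead of binding its own, giving a many-to-one fan-in. A minimal pyzmq sketch of that pattern, with an invented endpoint name:

    import zmq

    ctx = zmq.Context()
    endpoint = "ipc:///tmp/tokenizer_worker"  # placeholder name

    pull = ctx.socket(zmq.PULL)   # the single receiver binds the endpoint
    pull.bind(endpoint)

    push = ctx.socket(zmq.PUSH)   # each worker connects (bind=False)
    push.connect(endpoint)
    push.send_pyobj({"rid": "example"})
    print(pull.recv_pyobj())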
@@ -308,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

         # For PD disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -384,6 +368,9 @@ class TokenizerManager:
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.clear_hicache_storage_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -445,6 +432,10 @@ class TokenizerManager:
                 SlowDownReqOutput,
                 self.slow_down_communicator.handle_recv,
             ),
+            (
+                ClearHiCacheReqOutput,
+                self.clear_hicache_storage_communicator.handle_recv,
+            ),
             (
                 FlushCacheReqOutput,
                 self.flush_cache_communicator.handle_recv,
@@ -477,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -486,6 +508,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
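Note: the rid rewrite above pairs with get_origin_rid (imported from sglang.srt.utils earlier in this diff) to strip the worker prefix before results are surfaced to the client. The helper below is an assumption that merely mirrors the prefix scheme, not the library's code:

    def get_origin_rid(rid: str) -> str:  # hypothetical reconstruction
        # Drop the "{worker_id}_" prefix added by the tokenizer worker.
        return rid.split("_", 1)[1] if "_" in rid else rid

    assert get_origin_rid("4321_req-0") == "req-0"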
@@ -768,6 +799,30 @@ class TokenizerManager:
         self.rid_to_state[obj.rid] = state
         return state

+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +925,17 @@ class TokenizerManager:

             tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

-            for i, tokenized_obj in enumerate(tokenized_objs):
+            # Send as a single batched request
+            self._send_batch_request(obj, tokenized_objs, created_time)
+
+            # Set up generators for each request in the batch
+            for i in range(batch_size):
                 tmp_obj = obj[i]
-                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, state, request))
+                generators.append(
+                    self._wait_one_response(
+                        tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                    )
+                )
                 rids.append(tmp_obj.rid)
         else:
             # Sequential tokenization and processing
@@ -955,6 +1017,13 @@ class TokenizerManager:
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

+    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
+        """Clear the hierarchical cache storage."""
+        # Delegate to the scheduler to handle HiCacheStorage clearing
+        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
+            0
+        ]
+
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
@@ -1047,6 +1116,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1266,6 +1337,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1286,13 +1359,11 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]

-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
+    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
         responses: List[SetInternalStateReqOutput] = (
             await self.set_internal_state_communicator(obj)
         )
-        return [res.internal_state for res in responses]
+        return [res.updated for res in responses]

     async def get_load(self) -> dict:
         # TODO(lsyin): fake load report server
@@ -1543,7 +1614,6 @@ class TokenizerManager:

     async def handle_loop(self):
         """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1563,9 +1633,12 @@ class TokenizerManager:
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1871,6 +1944,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1883,7 +1959,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2069,6 +2145,8 @@ T = TypeVar("T")
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2085,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
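Note on the _Communicator changes above: a communicator sends one request and then gathers fan_out replies (one per DP rank), which is why callers such as clear_hicache_storage() index [0] into the returned list. A stripped-down sketch of that gather logic, assuming handle_recv is fed by the result dispatcher as registered earlier; names and details are illustrative, not the library's code:

    import asyncio

    class CommunicatorSketch:
        def __init__(self, sender, fan_out: int):
            self._sender = sender        # anything with send_pyobj()
            self._fan_out = fan_out      # replies to wait for (dp_size)
            self._values = None
            self._event = None

        async def __call__(self, obj):
            if obj:
                self._sender.send_pyobj(obj)
            self._values, self._event = [], asyncio.Event()
            await self._event.wait()     # set once every rank has replied
            values, self._values = self._values, None
            return values

        def handle_recv(self, recv_obj):
            self._values.append(recv_obj)
            if len(self._values) == self._fan_out:
                self._event.set()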
sglang/srt/mem_cache/chunk_cache.py

@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]