sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -82,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -255,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -515,6 +518,7 @@ class Scheduler(
                 (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                 (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
+                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -537,6 +541,7 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )

@@ -1098,6 +1103,17 @@ class Scheduler(
                 )
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
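
The new dispatch branch above is the scheduler half of the multi-tokenizer feature: requests arrive wrapped in a MultiTokenizerWarpper (spelling as in the source) that carries the originating worker's id, and any reply is re-wrapped with the same id so it can be routed back. A minimal sketch of that round trip, assuming the wrapper is essentially a two-field container (its real definition is in the io_struct.py changes, which are not shown in this diff):

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class MultiTokenizerWarpper:  # hypothetical stand-in for the io_struct class
    worker_id: int
    obj: Any

def dispatch_wrapped(recv_req: Any, dispatcher: Callable[[Any], Optional[Any]]):
    # Unwrap, run the inner request through the normal dispatcher, then
    # re-wrap any reply with the same worker_id for routing back.
    if isinstance(recv_req, MultiTokenizerWarpper):
        output = dispatcher(recv_req.obj)
        if output is not None:
            return MultiTokenizerWarpper(recv_req.worker_id, output)
        return None
    return dispatcher(recv_req)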
@@ -1503,7 +1519,7 @@ class Scheduler(
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
@@ -2207,6 +2223,16 @@ class Scheduler(
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)

+    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+        if self.enable_hierarchical_cache:
+            self.tree_cache.clear_storage_backend()
+            logger.info("Hierarchical cache cleared successfully!")
+            if_success = True
+        else:
+            logging.warning("Hierarchical cache is not enabled.")
+            if_success = False
+        return ClearHiCacheReqOutput(success=if_success)
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if (
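
The handler above only reports a boolean back to the tokenizer side, so the two new io_struct messages it depends on are presumably trivial. A hedged sketch of their likely shape, inferred only from how they are constructed here (the actual definitions are in the io_struct.py diff, not shown):

from dataclasses import dataclass

@dataclass
class ClearHiCacheReqInput:
    # No fields appear to be needed to request a storage wipe.
    pass

@dataclass
class ClearHiCacheReqOutput:
    success: bool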
@@ -2377,7 +2403,14 @@ class Scheduler(
                 # This only works for requests that have not started anything.
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
+                if self.enable_hicache_storage:
+                    # to release prefetch events associated with the request
+                    self.tree_cache.release_aborted_request(req.rid)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+                if self.disaggregation_mode == DisaggregationMode.DECODE:
+                    self.tree_cache.cache_finished_req(req)
+
                 logger.debug(f"Abort queued request. {req.rid=}")

         # Delete the requests in the grammar queue
@@ -2457,6 +2490,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)

-                if req.return_logprob:
+                if batch.return_logprob:
                     assert extend_logprob_start_len_per_req is not None
                     assert extend_input_len_per_req is not None
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     num_input_logprobs = extend_input_len - extend_logprob_start_len
-                    self.add_logprob_return_values(
-                        i,
-                        req,
-                        logprob_pt,
-                        next_token_ids,
-                        num_input_logprobs,
-                        logits_output,
-                    )
+                    if req.return_logprob:
+                        self.add_logprob_return_values(
+                            i,
+                            req,
+                            logprob_pt,
+                            next_token_ids,
+                            num_input_logprobs,
+                            logits_output,
+                        )
                     logprob_pt += num_input_logprobs

                 if (
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
                         skip_stream_req = req

                     # Incrementally update input logprobs.
-                    if req.return_logprob:
+                    if batch.return_logprob:
                         extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                         extend_input_len = extend_input_len_per_req[i]
                         if extend_logprob_start_len < extend_input_len:
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
                             num_input_logprobs = (
                                 extend_input_len - extend_logprob_start_len
                             )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
+                            if req.return_logprob:
+                                self.add_input_logprob_return_values(
+                                    i,
+                                    req,
+                                    logits_output,
+                                    logprob_pt,
+                                    num_input_logprobs,
+                                    last_prefill_chunk=False,
+                                )
                             logprob_pt += num_input_logprobs

             self.set_next_batch_sampling_info_done(batch)
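
Both hunks in this file make the same fix: the outer guard becomes batch.return_logprob while the per-request add_* call stays behind req.return_logprob, so logprob_pt (the read offset into the batch-packed logprob tensors) keeps advancing even for requests that did not ask for logprobs. A toy illustration of why the offset must advance unconditionally, assuming input logprobs for the whole extend batch are packed back to back:

# Toy data: input logprobs for three requests packed into one flat list.
packed = [0.11, 0.12] + [0.21, 0.22, 0.23] + [0.31]
lengths = [2, 3, 1]                  # per-request num_input_logprobs
wants_logprob = [True, False, True]  # per-request req.return_logprob

logprob_pt = 0
for length, wanted in zip(lengths, wants_logprob):
    if wanted:
        chunk = packed[logprob_pt : logprob_pt + length]
        print(chunk)
    # Advance even when this request skipped logprobs; otherwise the
    # next request would slice the wrong region of the packed buffer.
    logprob_pt += length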
sglang/srt/managers/scheduler_update_weights_mixin.py

@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
         url = params["url"]

         worker = self.tp_worker.worker
-
         worker.model_runner.save_remote_model(url)

+        if self.draft_worker is not None:
+            draft_url = params.get("draft_url", None)
+            assert (
+                draft_url is not None
+            ), "draft_url must be provided when draft model is enabled"
+            draft_worker = self.draft_worker.worker
+            draft_worker.model_runner.save_remote_model(draft_url)
+
     def save_sharded_model(self, params):
         worker = self.tp_worker.worker

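
With this change, save_remote_model also saves the draft model whenever a speculative-decoding draft worker is attached. A hedged example of the expected params payload (key names come from the hunk; the URL values are placeholders, not real endpoints):

# Hypothetical payload; both keys are required once a draft worker exists.
params = {
    "url": "<remote-url-for-target-model>",
    "draft_url": "<remote-url-for-draft-model>",
}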
sglang/srt/managers/tokenizer_manager.py

@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -92,6 +94,7 @@ from sglang.srt.managers.io_struct import (
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     LoRAUpdateResult,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -129,6 +132,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -264,9 +268,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )

         # Request states
         self.no_create_loop = False
@@ -310,35 +320,7 @@ class TokenizerManager:
         self.lora_update_lock = asyncio.Lock()

         # For PD disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_disaggregation()

         # For load balancing
         self.current_load = 0
@@ -386,6 +368,9 @@ class TokenizerManager:
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.clear_hicache_storage_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -447,6 +432,10 @@ class TokenizerManager:
                     SlowDownReqOutput,
                     self.slow_down_communicator.handle_recv,
                 ),
+                (
+                    ClearHiCacheReqOutput,
+                    self.clear_hicache_storage_communicator.handle_recv,
+                ),
                 (
                     FlushCacheReqOutput,
                     self.flush_cache_communicator.handle_recv,
@@ -479,6 +468,37 @@ class TokenizerManager:
             ]
         )

+    def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+            is_create_store = (
+                self.server_args.node_rank == 0
+                and self.server_args.disaggregation_transfer_backend == "ascend"
+            )
+            if is_create_store:
+                try:
+                    from mf_adapter import create_config_store
+
+                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
+                    create_config_store(ascend_url)
+                except Exception as e:
+                    error_message = f"Failed create mf store, invalid ascend_url."
+                    error_message += f" With exception {e}"
+                    raise error_message
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -488,6 +508,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
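
Because each tokenizer worker now prefixes request ids with f"{worker_id}_", the get_origin_rid helper imported above presumably just strips that prefix before ids are surfaced back to clients. A sketch under that assumption (the real helper lives in sglang/srt/utils.py, whose diff is not included here):

def get_origin_rid(rid: str) -> str:
    # Assumed inverse of the f"{self.worker_id}_{rid}" prefixing done in
    # generate_request(): drop everything up to the first underscore.
    return rid.split("_", 1)[1] if "_" in rid else rid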
@@ -988,6 +1017,13 @@ class TokenizerManager:
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

+    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
+        """Clear the hierarchical cache storage."""
+        # Delegate to the scheduler to handle HiCacheStorage clearing
+        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
+            0
+        ]
+
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
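
Like flush_cache directly above, the new coroutine fans a ClearHiCacheReqInput out to all dp_size schedulers through its _Communicator and returns the first reply. A minimal usage sketch, assuming tokenizer_manager is an already-initialized TokenizerManager:

async def wipe_hicache(tokenizer_manager) -> bool:
    result = await tokenizer_manager.clear_hicache_storage()
    return result.success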
@@ -1080,6 +1116,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1299,6 +1337,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None

+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWarpper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1319,13 +1359,11 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]

-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
+    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
         responses: List[SetInternalStateReqOutput] = (
             await self.set_internal_state_communicator(obj)
         )
-        return [res.internal_state for res in responses]
+        return [res.updated for res in responses]

     async def get_load(self) -> dict:
         # TODO(lsyin): fake load report server
@@ -1576,7 +1614,6 @@ class TokenizerManager:

     async def handle_loop(self):
         """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1596,9 +1633,12 @@ class TokenizerManager:
                 )
                 continue

+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1904,6 +1944,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
            return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1916,7 +1959,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2102,6 +2145,8 @@ T = TypeVar("T")
 class _Communicator(Generic[T]):
     """Note: The communicator now only run up to 1 in-flight request at any time."""

+    enable_multi_tokenizer = False
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
@@ -2118,6 +2163,8 @@ class _Communicator(Generic[T]):
         assert self._result_values is None

         if obj:
+            if _Communicator.enable_multi_tokenizer:
+                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
             self._sender.send_pyobj(obj)

         self._result_event = asyncio.Event()
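
The class-level enable_multi_tokenizer flag lets every _Communicator tag its outgoing control messages with the current process id instead of threading a worker id through each call site. A sketch of how a tokenizer worker might flip the flag at startup (the actual wiring lives in the new multi_tokenizer_mixin.py, which is not shown in this diff):

import os

def enable_multi_tokenizer_routing():
    # Hypothetical startup hook: after this, _Communicator.__call__ wraps every
    # outgoing obj as MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj).
    _Communicator.enable_multi_tokenizer = True
    return os.getpid()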
sglang/srt/mem_cache/chunk_cache.py

@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
sglang/srt/mem_cache/hicache_storage.py

@@ -102,6 +102,20 @@ class HiCacheStorage(ABC):
         """
         pass

+    @abstractmethod
+    def delete(self, key: str) -> bool:
+        """
+        Delete the entry associated with the given key.
+        """
+        pass
+
+    @abstractmethod
+    def clear(self) -> bool:
+        """
+        Clear all entries in the storage.
+        """
+        pass
+
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.
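
delete and clear are now part of the abstract HiCacheStorage contract, so every backend (file, hf3fs, mooncake, and so on) must implement them. A minimal sketch of just these two methods on a hypothetical dict-backed backend, to show the expected boolean return convention:

class InMemoryHiCacheStorage:  # hypothetical backend, not part of sglang
    def __init__(self):
        self._store = {}

    def delete(self, key: str) -> bool:
        # Best-effort delete: report success whether or not the key existed.
        self._store.pop(key, None)
        return True

    def clear(self) -> bool:
        self._store.clear()
        return True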
@@ -175,11 +189,12 @@ class HiCacheFile(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-        key = self._get_suffixed_key(key)
-        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
+
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
@@ -213,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
             logger.warning(f"Key {key} does not exist. Cannot delete.")
             return

-    def clear(self) -> None:
+    def clear(self) -> bool:
         try:
             for filename in os.listdir(self.file_path):
                 file_path = os.path.join(self.file_path, filename)
                 if os.path.isfile(file_path):
                     os.remove(file_path)
             logger.info("Cleared all entries in HiCacheFile storage.")
+            return True
         except Exception as e:
             logger.error(f"Failed to clear HiCacheFile storage: {e}")
+            return False