sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
     ):
-        self.engine = MooncakeTransferEngine()
         self.kv_args = args
+        self.engine = MooncakeTransferEngine(
+            hostname=get_local_ip_by_remote(),
+            gpu_id=self.kv_args.gpu_id,
+            ib_device=self.kv_args.ib_device,
+        )
         self.disaggregation_mode = disaggregation_mode
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -387,6 +391,10 @@ class MooncakeKVSender(BaseKVSender):


 class MooncakeKVReceiver(BaseKVReceiver):
+    _ctx = zmq.Context()
+    _socket_cache = {}
+    _socket_locks = {}
+    _global_lock = threading.Lock()

     def __init__(
         self,
@@ -436,11 +444,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    @
-    def _connect(
-
-
-
+    @classmethod
+    def _connect(cls, endpoint: str):
+        with cls._global_lock:
+            if endpoint not in cls._socket_cache:
+                sock = cls._ctx.socket(zmq.PUSH)
+                sock.connect(endpoint)
+                cls._socket_cache[endpoint] = sock
+                cls._socket_locks[endpoint] = threading.Lock()
+            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
         self.prefill_server_url = (
@@ -456,18 +468,20 @@ class MooncakeKVReceiver(BaseKVReceiver):
         packed_aux_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         )
-        self._connect("tcp://" + self.prefill_server_url)
-
-
-
-
-
-
-
-
-
-
-
+        sock, lock = self._connect("tcp://" + self.prefill_server_url)
+        with lock:
+            sock.send_multipart(
+                [
+                    str(self.bootstrap_room).encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.session_id.encode("ascii"),
+                    packed_kv_data_ptrs,
+                    kv_indices.tobytes(),
+                    packed_aux_data_ptrs,
+                    str(aux_index).encode("ascii"),
+                ]
+            )

     def poll(self) -> KVPoll:
         return self.kv_mgr.check_status(self.bootstrap_room)
@@ -493,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.thread.start()

     def _setup_routes(self):
-        self.app.router.add_route("*", "/metadata", self._handle_metadata)
         self.app.router.add_route("*", "/route", self._handle_route)

-    async def _handle_metadata(self, request: web.Request):
-        key = request.query.get("key", "")
-
-        if request.method == "GET":
-            return await self._handle_metadata_get(key)
-        elif request.method == "PUT":
-            return await self._handle_metadata_put(key, request)
-        elif request.method == "DELETE":
-            return await self._handle_metadata_delete(key)
-        return web.Response(
-            text="Method not allowed", status=405, content_type="application/json"
-        )
-
-    async def _handle_metadata_get(self, key):
-        async with self.lock:
-            value = self.store.get(key)
-        if value is None:
-            return web.Response(
-                text="metadata not found", status=404, content_type="application/json"
-            )
-        return web.Response(body=value, status=200, content_type="application/json")
-
-    async def _handle_metadata_put(self, key, request):
-        data = await request.read()
-        async with self.lock:
-            self.store[key] = data
-        return web.Response(
-            text="metadata updated", status=200, content_type="application/json"
-        )
-
-    async def _handle_metadata_delete(self, key):
-        async with self.lock:
-            if key not in self.store:
-                return web.Response(
-                    text="metadata not found",
-                    status=404,
-                    content_type="application/json",
-                )
-            del self.store[key]
-        return web.Response(
-            text="metadata deleted", status=200, content_type="application/json"
-        )
-
     async def _handle_route(self, request: web.Request):
         method = request.method
         if method == "PUT":
sglang/srt/disaggregation/mooncake/transfer_engine.py
CHANGED
@@ -1,45 +1,14 @@
 import json
 import logging
-import os
-import uuid
 from dataclasses import dataclass
+from typing import Optional

 logger = logging.getLogger(__name__)


-@dataclass
-class MooncakeTransferEngineConfig:
-    local_hostname: str
-    metadata_server: str
-    protocol: str
-    device_name: str
-
-    @staticmethod
-    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
-        """Load the config from a JSON file."""
-        with open(file_path) as fin:
-            config = json.load(fin)
-        return MooncakeTransferEngineConfig(
-            local_hostname=config.get("local_hostname", None),
-            metadata_server=config.get("metadata_server"),
-            protocol=config.get("protocol", "rdma"),
-            device_name=config.get("device_name", ""),
-        )
-
-    @staticmethod
-    def load_from_env() -> "MooncakeTransferEngineConfig":
-        """Load config from a file specified in the environment variable."""
-        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
-        if config_file_path is None:
-            raise ValueError(
-                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
-            )
-        return MooncakeTransferEngineConfig.from_file(config_file_path)
-
-
 class MooncakeTransferEngine:

-    def __init__(self):
+    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
         try:
             from mooncake.engine import TransferEngine
         except ImportError as e:
@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
             ) from e

         self.engine = TransferEngine()
+        self.hostname = hostname
+        self.gpu_id = gpu_id
+        self.ib_device = ib_device

-        try:
-            self.config = MooncakeTransferEngineConfig.load_from_env()
-            logger.info("Mooncake Configuration loaded successfully.")
-        except ValueError as e:
-            logger.error(e)
-            raise
-        except Exception as exc:
-            logger.error("An error occurred while loading the configuration: %s", exc)
-            raise
-
-        self.config = MooncakeTransferEngineConfig.load_from_env()
-
-        session_suffix = "_" + str(uuid.uuid4())
-        self.session_id = self.config.local_hostname + session_suffix
         self.initialize(
-            self.
-            self.
-            self.config.protocol,
-            self.config.device_name,
+            hostname=self.hostname,
+            device_name=self.ib_device,
         )
+        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"

     def register(self, ptr, length):
-        self.engine.register_memory(ptr, length)
+        ret_value = self.engine.register_memory(ptr, length)
+        if ret_value != 0:
+            logger.error("Mooncake memory registration failed.")
+            raise RuntimeError("Mooncake memory registration failed.")

     def deregister(self, ptr):
-        self.engine.unregister_memory(ptr)
+        ret_value = self.engine.unregister_memory(ptr)
+        if ret_value != 0:
+            logger.error("Mooncake memory deregistration failed.")
+            raise RuntimeError("Mooncake memory deregistration failed.")

     def initialize(
         self,
-
-
-        protocol: str,
-        device_name: str,
+        hostname: str,
+        device_name: Optional[str],
     ) -> None:
         """Initialize the mooncake instance."""
-        self.engine.initialize(
+        ret_value = self.engine.initialize(
+            hostname,
+            "P2PHANDSHAKE",
+            "rdma",
+            device_name if device_name is not None else "",
+        )
+        if ret_value != 0:
+            logger.error("Mooncake Transfer Engine initialization failed.")
+            raise RuntimeError("Mooncake Transfer Engine initialization failed.")

     def transfer_sync(
         self, session_id: str, buffer: int, peer_buffer_address: int, length: int
@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
             session_id, buffer, peer_buffer_address, length
         )
         if ret < 0:
-            logger.error("Transfer Return Error")
-            raise
+            logger.error("Mooncake Transfer Engine Return Error.")
+            raise RuntimeError("Mooncake Transfer Engine Return Error.")
         return ret

     def get_localhost(self):
-        return self.
+        return self.hostname

     def get_session_id(self):
         return self.session_id
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    kv_to_page_indices,
+    kv_to_page_num,
     poll_and_all_reduce,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -103,7 +105,7 @@ class PrefillBootstrapQueue:
         kv_args.aux_item_lens = [
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
-        kv_args.ib_device =
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert req.metadata_buffer_index is not None
-
+            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
+            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)

             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
@@ -171,6 +174,36 @@ class SchedulerDisaggregationPrefillMixin:
     Mixin for Scheduler to handle disaggregation prefill
     """

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
@@ -210,7 +243,7 @@ class SchedulerDisaggregationPrefillMixin:

         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.
+            self.attn_tp_cpu_group,
         )

         undone_reqs: List[Req] = []
@@ -270,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
             req.metadata_buffer_index, token_id
         )
         is_last = token_id is not None
-
+        page_indices = kv_to_page_indices(
+            kv_indices, self.token_to_kv_pool_allocator.page_size
+        )
+        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
sglang/srt/disaggregation/utils.py
CHANGED
@@ -4,6 +4,7 @@ from collections import deque
 from enum import Enum
 from typing import List

+import numpy as np
 import torch
 import torch.distributed as dist

@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         }
         return class_mapping.get(class_type)
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
+
+
+def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
+    # 1. The page is guaruanteed to be full except the last page.
+    # 2. page index = kv_index // page_size
+    # The return vector is kv_indices[::page_size] // page_size
+    if page_size == 1:  # shortcut
+        return kv_indices
+    return kv_indices[::page_size] // page_size
+
+
+def kv_to_page_num(num_kv_indices: int, page_size: int):
+    # ceil(num_kv_indices / page_size)
+    return (num_kv_indices + page_size - 1) // page_size
sglang/srt/entrypoints/verl_engine.py
CHANGED
@@ -12,18 +12,17 @@
 # limitations under the License.
 # ==============================================================================
 import os
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.server import Engine
-from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj


@@ -125,7 +124,7 @@ class VerlEngine:

     def update_weights_from_tensor(
         self,
-        named_tensors:
+        named_tensors: Iterable[Tuple[str, torch.Tensor]],
         load_format: Optional[str] = None,
     ):
         # Most naive implementation, can optimize a lot if it is bottleneck
@@ -154,9 +153,12 @@ class VerlEngine:
                     )
                 ],
                 load_format=load_format,
-                flush_cache=
+                flush_cache=False,
             )

+        if self._tp_rank == 0:
+            self._engine.tokenizer_manager.flush_cache()
+
     def release_memory_occupation(self):
         if self._tp_rank == 0:
             self._engine.release_memory_occupation()
sglang/srt/layers/activation.py
CHANGED
@@ -21,13 +21,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_cuda_available
-
-_is_cuda = is_cuda_available()
-
-if _is_cuda:
-    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 logger = logging.getLogger(__name__)
