sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    format_tcp_address,
+    get_free_port,
+    get_int_env_var,
+    get_ip,
+    get_local_ip_auto,
+    is_valid_ipv6_address,
+    maybe_wrap_ipv6_address,
+)

 logger = logging.getLogger(__name__)

@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
         self.request_status: Dict[int, KVPoll] = {}
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
+        if is_valid_ipv6_address(self.local_ip):
+            self.server_socket.setsockopt(zmq.IPV6, 1)
+
         self.register_buffer_to_engine()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
            self.engine.register(aux_data_ptr, aux_data_len)

     @cache
-    def _connect(self, endpoint: str):
+    def _connect(self, endpoint: str, is_ipv6: bool = False):
         socket = zmq.Context().socket(zmq.PUSH)
+        if is_ipv6:
+            socket.setsockopt(zmq.IPV6, 1)
         socket.connect(endpoint)
         return socket

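The three hunks above make the Mooncake bootstrap sockets IPv6-aware: the PULL server socket enables `zmq.IPV6` only when the local address is an IPv6 literal, and `_connect` gains an `is_ipv6` flag for the PUSH side. A minimal standalone sketch of the same pattern, approximating the sglang helpers (`is_valid_ipv6_address`, `format_tcp_address`) with the standard library, since their implementations are not shown in this diff:

```python
# Standalone sketch (assumes pyzmq); the sglang helpers are approximated here.
import ipaddress

import zmq


def is_ipv6(host: str) -> bool:
    try:
        return ipaddress.ip_address(host).version == 6
    except ValueError:
        return False  # hostnames and malformed strings are treated as non-IPv6


def tcp_address(host: str, port: int) -> str:
    # ZMQ expects IPv6 literals to be wrapped in brackets: tcp://[::1]:5555
    return f"tcp://[{host}]:{port}" if is_ipv6(host) else f"tcp://{host}:{port}"


def make_pull_socket(ctx: zmq.Context, host: str, port: int) -> zmq.Socket:
    sock = ctx.socket(zmq.PULL)
    if is_ipv6(host):
        sock.setsockopt(zmq.IPV6, 1)  # IPv6 is off by default on ZMQ sockets
    sock.bind(tcp_address(host, port))
    return sock


print(tcp_address("2001:db8::1", 5555))  # tcp://[2001:db8::1]:5555
```

ZMQ requires bracketed IPv6 literals in `tcp://` endpoints, which is why the address formatting and the socket option have to change together.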
@@ -321,67 +334,60 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_rank = self.kv_args.engine_rank
         local_tp_size = self.tp_size // self.dp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size

         # Calculate head distribution
-
-
-
-
-
-
-        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )

         # Determine slicing parameters based on TP configuration
         if local_tp_size > dst_tp_size:
-
-
-
+            # Send KVCache from multiple prefill instances to 1 decode instance
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
         else:
-
-
-
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0

-
+        layers_params = []
         for layer_id in range(num_layers):
-
-
-
-
-
-
-                logger.error(
-                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
-                )
-                return -1
-
-            # Calculate precise byte offset and length for the sub-slice within the prefill page data
-            src_slice_offset = src_head_offset * bytes_per_head
-            dst_slice_offset = dst_head_offset * bytes_per_head
-            slice_lens_per_page = num_heads_to_send * bytes_per_head
+            # Calculate precise byte offset and length for the sub-slice within the token
+            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+            heads_bytes_per_token_to_send = (
+                num_heads_to_send * bytes_per_head_slice_to_send
+            )

-            # Sanity check: The data sub-slice to be sent should fit into the
-            # This means
-            if
+            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
                 logger.error(
                     f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({
-                    f"target
+                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
+                    f"target token slot size ({dst_kv_item_len // page_size})"
                 )
                 return -1
-
+            layers_params.append(
                 (
                     self.kv_args.kv_data_ptrs[layer_id],
                     dst_kv_ptrs[layer_id],
-
-
-
-
-
+                    src_kv_item_len,
+                    dst_kv_item_len,
+                    src_head_slice_offset,
+                    dst_head_slice_offset,
+                    heads_bytes_per_token_to_send,
                 )
             )

@@ -391,9 +397,9 @@ class MooncakeKVManager(BaseKVManager):
                 dst_ptr,
                 src_item_len,
                 dst_item_len,
-
-
-
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             ) = layer_params
             src_addr_list = []
             dst_addr_list = []
@@ -424,17 +430,12 @@ class MooncakeKVManager(BaseKVManager):
                 )

                 # Calculate final src and dst addresses by applying head-slice offsets
-                src_slice_addr = src_token_slot_start_addr +
-                dst_slice_addr = dst_token_slot_start_addr +
+                src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset

                 src_addr_list.append(src_slice_addr)
                 dst_addr_list.append(dst_slice_addr)
-                length_list.append(
-
-                logger.debug(
-                    f"SYNC: sid={mooncake_session_id}, "
-                    f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
-                )
+                length_list.append(heads_bytes_per_token_to_send)

             return self.engine.batch_transfer_sync(
                 mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
|
445
446
|
process_layer_tp_aware,
|
446
447
|
layer_params,
|
447
448
|
)
|
448
|
-
for layer_params in
|
449
|
+
for layer_params in layers_params
|
449
450
|
]
|
450
451
|
|
451
452
|
for future in concurrent.futures.as_completed(futures):
|
@@ -483,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
483
484
|
def sync_status_to_decode_endpoint(
|
484
485
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
485
486
|
):
|
486
|
-
|
487
|
-
remote =
|
488
|
-
|
487
|
+
self._connect(
|
488
|
+
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
489
|
+
).send_multipart(
|
489
490
|
[
|
490
491
|
str(room).encode("ascii"),
|
491
492
|
str(status).encode("ascii"),
|
@@ -533,12 +534,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
533
534
|
if len(chunked_dst_kv_indice) < len(
|
534
535
|
kv_chunk.prefill_kv_indices
|
535
536
|
):
|
536
|
-
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
|
537
|
-
: len(chunked_dst_kv_indice)
|
538
|
-
]
|
539
537
|
logger.warning(
|
540
538
|
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
541
539
|
)
|
540
|
+
kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
|
541
|
+
: len(chunked_dst_kv_indice)
|
542
|
+
]
|
542
543
|
|
543
544
|
target_rank_registration_info: KVArgsRegisterInfo = (
|
544
545
|
self.decode_kv_args_table[req.mooncake_session_id]
|
@@ -628,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
628
629
|
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
|
629
630
|
)
|
630
631
|
|
632
|
+
def _bind_server_socket(self):
|
633
|
+
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
634
|
+
|
631
635
|
def start_prefill_thread(self):
|
632
636
|
self.rank_port = get_free_port()
|
633
|
-
self.
|
637
|
+
self._bind_server_socket()
|
634
638
|
|
635
639
|
def bootstrap_thread():
|
636
640
|
"""This thread recvs pre-alloc notification from the decode engine"""
|
@@ -669,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
669
673
|
|
670
674
|
def start_decode_thread(self):
|
671
675
|
self.rank_port = get_free_port()
|
672
|
-
self.
|
676
|
+
self._bind_server_socket()
|
673
677
|
|
674
678
|
def decode_thread():
|
675
679
|
while True:
|
@@ -788,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
788
792
|
# requests with the same dst_sessions will be added into the same
|
789
793
|
# queue, which enables early abort with failed sessions.
|
790
794
|
dst_infos = self.transfer_infos[bootstrap_room].keys()
|
791
|
-
session_port_sum = sum(int(session.
|
795
|
+
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
|
792
796
|
shard_idx = session_port_sum % len(self.transfer_queues)
|
793
797
|
|
794
798
|
self.transfer_queues[shard_idx].put(
|
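The `rsplit(":", 1)` in the queue-sharding change is what keeps this working for IPv6 session ids: splitting on the last colon still isolates the port even when the host part itself contains colons. For example:

```python
# Why rsplit(":", 1) rather than split(":"): an IPv6 host contains colons,
# so only the final one separates the port.
for session in ["10.0.0.1:5000", "[2001:db8::1]:5001"]:
    host, port = session.rsplit(":", 1)
    print(host, int(port))
```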
@@ -826,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
     def _register_to_bootstrap(self):
         """Register KVSender to bootstrap server via HTTP POST."""
         if self.dist_init_addr:
-
+            if self.dist_init_addr.startswith("["):  # [ipv6]:port or [ipv6]
+                if self.dist_init_addr.endswith("]"):
+                    host = self.dist_init_addr
+                else:
+                    host, _ = self.dist_init_addr.rsplit(":", 1)
+            else:
+                host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
         else:
-
+            host = get_ip()
+        host = maybe_wrap_ipv6_address(host)

-        bootstrap_server_url = f"{
+        bootstrap_server_url = f"{host}:{self.bootstrap_port}"
         url = f"http://{bootstrap_server_url}/route"
         payload = {
             "role": "Prefill",
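The new `_register_to_bootstrap` branch accepts three `dist_init_addr` forms: a bracketed IPv6 literal with a port, a bare bracketed literal, and a plain `host:port`. A sketch of just the string handling (the real code additionally resolves plain hostnames with `socket.gethostbyname`; the example addresses are made up):

```python
# String handling only; DNS resolution of plain hostnames is omitted here.
def resolve_host(dist_init_addr: str) -> str:
    if dist_init_addr.startswith("["):          # "[::1]:8998" or "[::1]"
        if dist_init_addr.endswith("]"):
            return dist_init_addr
        return dist_init_addr.rsplit(":", 1)[0]
    return dist_init_addr.rsplit(":", 1)[0]     # "node0:8998" -> "node0"


for addr in ["[2001:db8::1]:8998", "[2001:db8::1]", "prefill-node:8998"]:
    print(addr, "->", resolve_host(addr))
```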
@@ -1175,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):

     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
-            self.prefill_server_url = (
-                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
-            )
             packed_kv_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
             )
@@ -1191,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             dst_tp_size = str(tp_size).encode("ascii")
             dst_kv_item_len = str(kv_item_len).encode("ascii")

-            sock, lock = self.
+            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
             with lock:
                 sock.send_multipart(
                     [
@@ -1208,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 )

     @classmethod
-    def _connect(cls, endpoint: str):
+    def _connect(cls, endpoint: str, is_ipv6: bool = False):
         with cls._global_lock:
             if endpoint not in cls._socket_cache:
                 sock = cls._ctx.socket(zmq.PUSH)
+                if is_ipv6:
+                    sock.setsockopt(zmq.IPV6, 1)
                 sock.connect(endpoint)
                 cls._socket_cache[endpoint] = sock
                 cls._socket_locks[endpoint] = threading.Lock()
             return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

+    @classmethod
+    def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
+        ip_address = bootstrap_info["rank_ip"]
+        port = bootstrap_info["rank_port"]
+        is_ipv6_address = is_valid_ipv6_address(ip_address)
+        sock, lock = cls._connect(
+            format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
+        )
+        return sock, lock
+
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
-
-                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
-            )
+            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
             is_dummy = bootstrap_info["is_dummy"]

-            sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
                 sock.send_multipart(
                     [
sglang/srt/disaggregation/mooncake/transfer_engine.py

@@ -1,7 +1,7 @@
 import logging
 from typing import List, Optional

-from sglang.srt.utils import get_bool_env_var, get_free_port
+from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address

 logger = logging.getLogger(__name__)

@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
             hostname=self.hostname,
             device_name=self.ib_device,
         )
-        self.session_id =
+        self.session_id = (
+            f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
+        )

     def register(self, ptr, length):
         try:
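`maybe_wrap_ipv6_address` is imported here so the Mooncake session id stays parseable as `host:port` even when the hostname is an IPv6 literal. Its implementation is not part of this diff; a plausible equivalent, shown only as an assumption:

```python
# Assumed behavior: bracket IPv6 literals so "host:port" strings stay parseable.
import ipaddress


def maybe_wrap_ipv6_address(host: str) -> str:
    try:
        if ipaddress.ip_address(host).version == 6:
            return f"[{host}]"
    except ValueError:
        pass  # hostnames and already-bracketed addresses fall through unchanged
    return host


print(maybe_wrap_ipv6_address("2001:db8::1") + ":12345")  # [2001:db8::1]:12345
print(maybe_wrap_ipv6_address("10.0.0.1") + ":12345")     # 10.0.0.1:12345
```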
sglang/srt/disaggregation/nixl/conn.py

@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
 from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    format_tcp_address,
+    get_local_ip_auto,
+    is_valid_ipv6_address,
+)

 logger = logging.getLogger(__name__)

@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
                 "to run SGLang with NixlTransferEngine."
             ) from e
         self.agent = nixl_agent(str(uuid.uuid4()))
+        self.local_ip = get_local_ip_auto()
         self.server_socket = zmq.Context().socket(zmq.PULL)
+        if is_valid_ipv6_address(self.local_ip):
+            self.server_socket.setsockopt(zmq.IPV6, 1)
         self.register_buffer_to_engine()

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
             return False
         return self.transfer_statuses[room].is_done()

+    def _bind_server_socket(self):
+        self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
+
     def _start_bootstrap_thread(self):
-        self.
+        self._bind_server_socket()

         def bootstrap_thread():
             """This thread recvs transfer info from the decode engine"""
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
-            self.prefill_server_url = (
-                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
-            )
             logger.debug(
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
             )
+            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
             is_dummy = bootstrap_info["is_dummy"]
             logger.debug(
-                f"Sending to
+                f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
             )
-            sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
                 sock.send_multipart(
                     [
                         GUARD,
                         str(self.bootstrap_room).encode("ascii"),
-
+                        self.kv_mgr.local_ip.encode("ascii"),
                         str(self.kv_mgr.rank_port).encode("ascii"),
                         self.kv_mgr.agent.name.encode("ascii"),
                         kv_indices.tobytes() if not is_dummy else b"",
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):

     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
-
-                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
-            )
+            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
             packed_kv_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
             )
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )

-            sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
                 sock.send_multipart(
                     [
                         GUARD,
                         "None".encode("ascii"),
-
+                        self.kv_mgr.local_ip.encode("ascii"),
                         str(self.kv_mgr.rank_port).encode("ascii"),
                         self.kv_mgr.agent.name.encode("ascii"),
                         self.kv_mgr.agent.get_agent_metadata(),
sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -4,18 +4,18 @@ import ctypes
 import logging
 import os
 from contextlib import contextmanager
-from
-from typing import Any, Callable, List, Optional, TypeVar, Union
+from typing import Any, List, Optional, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
-from typing_extensions import ParamSpec

 from sglang.srt import _custom_ops as ops
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check,
+    is_full_nvlink,
+    is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import is_cuda, is_hip
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 _is_hip = is_hip()

-if _is_cuda:
-    try:
-        import pynvml
-    except ImportError as e:
-        logger.warning("Failed to import pynvml with %r", e)
-
-if _is_hip:
-    try:
-        from amdsmi import (
-            AmdSmiException,
-            amdsmi_get_processor_handles,
-            amdsmi_init,
-            amdsmi_shut_down,
-            amdsmi_topo_get_link_type,
-        )
-    except ImportError as e:
-        logger.warning("Failed to import amdsmi with %r", e)

 try:
     if ops.use_vllm_custom_allreduce and not _is_hip:
@@ -57,70 +40,6 @@ except Exception:

 logger = logging.getLogger(__name__)

-_P = ParamSpec("_P")
-_R = TypeVar("_R")
-
-
-def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
-    @wraps(fn)
-    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if _is_hip:
-            try:
-                amdsmi_init()
-                return fn(*args, **kwargs)
-            finally:
-                amdsmi_shut_down()
-        else:
-            pynvml.nvmlInit()
-            try:
-                return fn(*args, **kwargs)
-            finally:
-                pynvml.nvmlShutdown()
-
-    return wrapper
-
-
-@with_nvml_context
-def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if _is_hip:
-        """
-        query if the set of gpus are fully connected by xgmi (1 hop)
-        """
-        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
-        for i, handle in enumerate(handles):
-            for j, peer_handle in enumerate(handles):
-                if i < j:
-                    try:
-                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
-                        # type is 2 for XGMI
-                        if link_type["hops"] != 1 or link_type["type"] != 2:
-                            return False
-                    except AmdSmiException as error:
-                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
-                        return False
-        return True
-    else:
-        """
-        query if the set of gpus are fully connected by nvlink (1 hop)
-        """
-        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
-        for i, handle in enumerate(handles):
-            for j, peer_handle in enumerate(handles):
-                if i < j:
-                    try:
-                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
-                        )
-                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                            return False
-                    except pynvml.NVMLError:
-                        logger.exception(
-                            "NVLink detection failed. This is normal if your"
-                            " machine has no NVLink equipped."
-                        )
-                        return False
-        return True
-

 def _can_p2p(rank: int, world_size: int) -> bool:
     # SGLANG_SKIP_P2P_CHECK can be set to False in sglang
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     return True


-def is_weak_contiguous(inp: torch.Tensor):
-    return inp.is_contiguous() or (
-        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
-        == inp.numel() * inp.element_size()
-    )
-
-
 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
     _MAX_CAR_SIZE = 8192 * 1024
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py

@@ -8,17 +8,44 @@ import pickle
 import subprocess
 import sys
 import tempfile
+from functools import wraps
 from itertools import product
-from typing import Dict, List, Optional, Sequence
+from typing import Callable, Dict, List, Optional, Sequence, TypeVar

 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from typing_extensions import ParamSpec

 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
+from sglang.srt.utils import is_cuda, is_hip

 logger = logging.getLogger(__name__)

+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda:
+    try:
+        import pynvml
+    except ImportError as e:
+        logger.warning("Failed to import pynvml with %r", e)
+
+if _is_hip:
+    try:
+        from amdsmi import (
+            AmdSmiException,
+            amdsmi_get_processor_handles,
+            amdsmi_init,
+            amdsmi_shut_down,
+            amdsmi_topo_get_link_type,
+        )
+    except ImportError as e:
+        logger.warning("Failed to import amdsmi with %r", e)
+
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+

 def update_environment_variables(envs: Dict[str, str]):
     for k, v in envs.items():
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
     return _gpu_p2p_access_cache[f"{src}->{tgt}"]


+def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+    @wraps(fn)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+        if _is_hip:
+            try:
+                amdsmi_init()
+                return fn(*args, **kwargs)
+            finally:
+                amdsmi_shut_down()
+        else:
+            pynvml.nvmlInit()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+@with_nvml_context
+def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
+    if _is_hip:
+        """
+        query if the set of gpus are fully connected by xgmi (1 hop)
+        """
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
+                        # type is 2 for XGMI
+                        if link_type["hops"] != 1 or link_type["type"] != 2:
+                            return False
+                    except AmdSmiException as error:
+                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
+                        return False
+        return True
+    else:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                        )
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError:
+                        logger.exception(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped."
+                        )
+                        return False
+        return True
+
+
+def is_weak_contiguous(inp: torch.Tensor):
+    return inp.is_contiguous() or (
+        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
+        == inp.numel() * inp.element_size()
+    )
+
+
 __all__ = ["gpu_p2p_access_check"]

 if __name__ == "__main__":
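`with_nvml_context`, `is_full_nvlink`, and `is_weak_contiguous` move from `custom_all_reduce.py` into this utils module, with `custom_all_reduce.py` now importing them (see its hunk above). As a quick illustration of what the relocated `is_weak_contiguous` predicate accepts, using the definition shown in the hunk:

```python
# A transposed view is not contiguous, but it still covers its whole storage,
# so the custom all-reduce can treat it as one flat buffer; a column slice
# only covers part of the storage and is rejected.
import torch


def is_weak_contiguous(inp: torch.Tensor) -> bool:
    return inp.is_contiguous() or (
        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
        == inp.numel() * inp.element_size()
    )


x = torch.randn(4, 4)
print(is_weak_contiguous(x.t()))     # True: full storage, just strided
print(is_weak_contiguous(x[:, :2]))  # False: covers only half the storage
```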