sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+            self.get_key_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     # Todo: different memory layout
@@ -414,6 +418,7 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -441,13 +446,28 @@ class MLATokenToKVPool(KVCache):
         ]
 
         self.layer_transfer_counter = None
+        self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes
 
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
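Both get_contiguous_buf_infos changes above multiply the per-item length by self.page_size, so the disaggregation path transfers whole KV pages rather than single token slots. Below is a standalone sketch of that arithmetic; the buffer shape and page size are made up for illustration and are not sglang defaults.

import torch

# Hypothetical per-layer key buffer laid out as [num_tokens, num_heads, head_dim].
page_size = 16
key_buffer = torch.empty(4096, 8, 128, dtype=torch.float16)

per_token_bytes = key_buffer[0].nbytes        # one token slot: 8 * 128 * 2 = 2048 bytes
per_page_bytes = per_token_bytes * page_size  # one transfer item: 32768 bytes
print(per_token_bytes, per_page_bytes)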
@@ -616,26 +636,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.host_to_device_ratio = host_to_device_ratio
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-
-        self.size = int(device_pool.size * host_to_device_ratio)
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-        self.dtype = device_pool.store_dtype
-        self.size_per_token = self.get_size_per_token()
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
 
         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
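The reworked HostKVCache.__init__ above lets an explicit host_size (in GB) override the old ratio-based sizing, aligns the pool to the page size, and then asserts the host pool is still larger than the device pool. A self-contained sketch of just that sizing logic follows; the numeric inputs are illustrative, not sglang defaults.

def host_pool_tokens(
    host_size_gb: int,
    size_per_token: int,
    device_pool_size: int,
    host_to_device_ratio: float,
    page_size: int,
) -> int:
    # An explicit host size (GB) wins; otherwise fall back to the device ratio.
    if host_size_gb > 0:
        size = int(host_size_gb * 1e9 // size_per_token)
    else:
        size = int(device_pool_size * host_to_device_ratio)
    # Align the host memory pool size to the page size.
    size -= size % page_size
    assert size > device_pool_size, (
        "The host memory should be larger than the device memory with the current protocol"
    )
    return size

# A 4 GB host pool at 131072 bytes per token with page_size 16 -> 30512 token slots.
print(host_pool_tokens(4, 131072, 20000, 0.0, 16))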
@@ -787,12 +808,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )
 
     def get_size_per_token(self):
@@ -861,12 +883,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )
 
     def get_size_per_token(self):
@@ -35,13 +35,17 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import get_available_gpu_memory, is_hip
-
-_is_hip = is_hip()
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    is_hip,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+_is_hip = is_hip()
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -129,7 +133,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
         )
 
-    if _is_hip:
+    gpu_mem = get_device_memory_capacity()
+    if gpu_mem is not None and gpu_mem > 81920:
         capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -140,12 +145,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         ]
 
     capture_bs = list(sorted(set(capture_bs)))
-    capture_bs = [
-        bs
-        for bs in capture_bs
-        if bs <= model_runner.req_to_token_pool.size
-        and bs <= server_args.cuda_graph_max_bs
-    ]
+
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
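The capture-list change above replaces the single combined filter with an assertion and two independent caps, applying the cuda_graph_max_bs cap only when it is actually set. A small sketch of the resulting filtering, with made-up batch sizes and limits:

from typing import List, Optional

def filter_capture_bs(
    capture_bs: List[int],
    req_pool_size: int,
    cuda_graph_max_bs: Optional[int],
) -> List[int]:
    capture_bs = sorted(set(capture_bs))
    assert len(capture_bs) > 0 and capture_bs[0] > 0
    capture_bs = [bs for bs in capture_bs if bs <= req_pool_size]
    if cuda_graph_max_bs:  # only cap when explicitly configured
        capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
    return capture_bs

print(filter_capture_bs([1, 2, 4, 8, 16, 160, 256], req_pool_size=200, cuda_graph_max_bs=None))
# -> [1, 2, 4, 8, 16, 160]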
@@ -186,6 +190,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -42,6 +42,10 @@ from sglang.srt.layers.dp_attention import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
+from sglang.srt.layers.quantization.deep_gemm import (
+    _ENABLE_JIT_DEEPGEMM,
+    update_deep_gemm_config,
+)
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -73,6 +77,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
+    get_bool_env_var,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -127,10 +132,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.use_mla_backend = (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        )
+        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
         # Model-specific adjustment
@@ -139,18 +141,12 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
 
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
@@ -160,13 +156,12 @@ class ModelRunner:
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "moe_dense_tp_size": server_args.moe_dense_tp_size,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                 "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
             }
@@ -178,6 +173,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # Update deep gemm configure
+        if _ENABLE_JIT_DEEPGEMM:
+            update_deep_gemm_config(gpu_id, server_args)
+
         # If it is a draft model tp_group can be different.
         self.initialize(min_per_gpu_memory)
 
@@ -229,16 +228,17 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if server_args.enable_flashinfer_mla:
-            # TODO: remove this branch after enable_flashinfer_mla is deprecated
-            logger.info("MLA optimization is turned on. Use flashinfer backend.")
-            server_args.attention_backend = "flashinfer"
-        elif server_args.enable_flashmla:
-            # TODO: remove this branch after enable_flashmla is deprecated
-            logger.info("MLA optimization is turned on. Use flashmla decode.")
-            server_args.attention_backend = "flashmla"
-        elif server_args.attention_backend is None:
-            # By default, use flashinfer for non-mla attention and triton for mla attention
+        if server_args.attention_backend is None:
+            """
+            We auto select the fastest attention backend according to the current offering
+            1. Models with MHA Architecture (e.g: Llama, QWen)
+                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
+                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
+            2. Models with MLA Architecture and using FA3
+                2.1 We will use FA3 backend on hopper.
+                2.2 Otherwise, we will use triton backend.
+            """
+
             if not self.use_mla_backend:
                 if (
                     is_hopper_with_cuda_12_3()
@@ -251,9 +251,7 @@ class ModelRunner:
                         "flashinfer" if is_flashinfer_available() else "triton"
                     )
             else:
-                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
-                    server_args
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
@@ -263,7 +261,12 @@ class ModelRunner:
         elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
+                if server_args.attention_backend in [
+                    "flashinfer",
+                    "fa3",
+                    "triton",
+                    "flashmla",
+                ]:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
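Taken together, the three hunks above make the new docstring's auto-selection concrete: when no backend is requested, MHA models prefer FA3 on Hopper (otherwise flashinfer or triton), MLA models prefer FA3 on Hopper (otherwise triton), and flashmla is now accepted as an explicit MLA backend. A rough, simplified paraphrase of that decision tree follows; the real code also consults spec-decode settings, page size, device type, and architecture checks.

def pick_attention_backend(
    use_mla: bool,
    on_hopper_cuda_12_3: bool,
    flashinfer_available: bool,
    fa3_default_architecture: bool,
) -> str:
    # Simplified: omits spec-decode topk, page_size, and CPU-device checks.
    if not use_mla:
        if on_hopper_cuda_12_3 and fa3_default_architecture:
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    return "fa3" if on_hopper_cuda_12_3 else "triton"

print(pick_attention_backend(True, True, True, False))  # -> fa3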
@@ -320,7 +323,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
         if not self.use_mla_backend:
-            logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
@@ -387,10 +389,16 @@ class ModelRunner:
         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
-                raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
-                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
-                )
+                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
+                    logger.warning(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
+                else:
+                    raise ValueError(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
 
         logger.info(
             f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"