sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py CHANGED
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs

-    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
sglang/srt/utils.py CHANGED
@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import functools
 import importlib
 import io
 import ipaddress
@@ -159,7 +160,7 @@ def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()


-def is_cpu() -> bool:
+def is_host_cpu_x86() -> bool:
     machine = platform.machine().lower()
     return (
         machine in ("x86_64", "amd64", "i386", "i686")
@@ -168,6 +169,10 @@ def is_cpu() -> bool:
     )


+def is_cpu() -> bool:
+    return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
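Note that the new is_cpu() is an opt-in gate rather than a pure hardware probe: it only reports True when SGLANG_USE_CPU_ENGINE=1 is set and the host passes the x86 check. Not part of the diff, a minimal usage sketch assuming sglang 0.4.8 is installed:

import os

# Opt into the CPU engine before querying the helpers; is_cpu() reads the
# environment variable at call time.
os.environ["SGLANG_USE_CPU_ENGINE"] = "1"

from sglang.srt.utils import is_cpu, is_host_cpu_x86

print(is_host_cpu_x86())  # typically True on an x86_64/amd64 host
print(is_cpu())           # True only because SGLANG_USE_CPU_ENGINE=1 is set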
@@ -837,6 +842,7 @@ class CustomCacheManager(FileCacheManager):


 def set_ulimit(target_soft_limit=65535):
+    # number of open files
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)

@@ -846,6 +852,18 @@ def set_ulimit(target_soft_limit=65535):
     except ValueError as e:
         logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")

+    # stack size
+    resource_type = resource.RLIMIT_STACK
+    current_soft, current_hard = resource.getrlimit(resource_type)
+    target_soft_limit_stack_size = 1024 * target_soft_limit
+    if current_soft < target_soft_limit_stack_size:
+        try:
+            resource.setrlimit(
+                resource_type, (target_soft_limit_stack_size, current_hard)
+            )
+        except ValueError as e:
+            logger.warning(f"Fail to set RLIMIT_STACK: {e}")
+

 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
@@ -1277,6 +1295,15 @@ def get_hpu_memory_capacity():
     )


+def get_npu_memory_capacity():
+    try:
+        import torch_npu
+
+        return torch.npu.mem_get_info()[1] // 1024 // 1024  # unit: MB
+    except ImportError as e:
+        raise ImportError("torch_npu is required when run on npu device.")
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
@@ -1284,6 +1311,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_amdgpu_memory_capacity()
     elif device == "hpu":
         gpu_mem = get_hpu_memory_capacity()
+    elif device == "npu":
+        gpu_mem = get_npu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1373,6 +1402,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)


+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
@@ -1404,6 +1438,11 @@ def get_device(device_id: Optional[int] = None) -> str:
             return "xpu"
         return "xpu:{}".format(device_id)

+    if hasattr(torch, "npu") and torch.npu.is_available():
+        if device_id == None:
+            return "npu"
+        return "npu:{}".format(device_id)
+
     if is_habana_available():
         try:
             import habana_frameworks.torch.hpu
@@ -1417,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str:
             "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
         )

+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")

@@ -1478,15 +1526,35 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor


+def get_npu_compiler_config():
+    config = {
+        "frozen_parameter": True,
+        "tiling_schedule_optimize": True,
+        "topology_sorting_strategy": "StableRDFS",
+    }
+    return config
+
+
 def get_compiler_backend() -> str:
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         return "hpu_backend"

     if hasattr(torch, "npu") and torch.npu.is_available():
-        import torchair
+        try:
+            import torchair
+            import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
+            from torchair.configs.compiler_config import CompilerConfig
+        except ImportError as e:
+            raise ImportError(
+                "NPU detected, but torchair package is not installed. "
+                "Please install torchair for torch.compile support on NPU."
+            )
+        compiler_config = CompilerConfig()
+        predefined_config = get_npu_compiler_config()
+        for k, v in predefined_config.items():
+            setattr(compiler_config.experimental_config, k, v)

-        config = torchair.CompilerConfig()
-        npu_backend = torchair.get_npu_backend(compiler_config=config)
+        npu_backend = torchair.get_npu_backend(compiler_config=compiler_config)
         return npu_backend

     return "inductor"
@@ -1849,13 +1917,6 @@ def configure_ipv6(dist_init_addr):
     return port, host


-def rank0_log(msg: str):
-    from sglang.srt.distributed import get_tensor_model_parallel_rank
-
-    if get_tensor_model_parallel_rank() == 0:
-        logger.info(msg)
-
-
 def rank0_print(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

@@ -1863,6 +1924,9 @@ def rank0_print(msg: str):
         print(msg, flush=True)


+rank0_log = rank0_print
+
+
 def get_cuda_version():
     if torch.version.cuda:
         return tuple(map(int, torch.version.cuda.split(".")))
@@ -2086,6 +2150,44 @@ def get_free_port():
     return s.getsockname()[1]


+def get_local_ip_auto() -> str:
+    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
+    return (
+        get_local_ip_by_nic(interface)
+        if interface is not None
+        else get_local_ip_by_remote()
+    )
+
+
+def get_local_ip_by_nic(interface: str) -> str:
+    try:
+        import netifaces
+    except ImportError as e:
+        raise ImportError(
+            "Environment variable SGLANG_LOCAL_IP_NIC requires package netifaces, please install it through 'pip install netifaces'"
+        ) from e
+
+    try:
+        addresses = netifaces.ifaddresses(interface)
+        if netifaces.AF_INET in addresses:
+            for addr_info in addresses[netifaces.AF_INET]:
+                ip = addr_info.get("addr")
+                if ip and ip != "127.0.0.1" and ip != "0.0.0.0":
+                    return ip
+        if netifaces.AF_INET6 in addresses:
+            for addr_info in addresses[netifaces.AF_INET6]:
+                ip = addr_info.get("addr")
+                if ip and not ip.startswith("fe80::") and ip != "::1":
+                    return ip.split("%")[0]
+    except (ValueError, OSError) as e:
+        raise ValueError(
+            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        )
+
+    # Fallback
+    return get_local_ip_by_remote()
+
+
 def get_local_ip_by_remote() -> str:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
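The new get_local_ip_auto() prefers an explicitly named NIC and otherwise falls back to the existing remote-probe lookup. Not part of the diff, a small usage sketch on a networked host (the interface name eth0 is illustrative; the NIC path needs pip install netifaces):

import os

from sglang.srt.utils import get_local_ip_auto

# Default: SGLANG_LOCAL_IP_NIC is unset, so this falls back to get_local_ip_by_remote().
print(get_local_ip_auto())

# Pin the lookup to a specific network interface instead.
os.environ["SGLANG_LOCAL_IP_NIC"] = "eth0"
print(get_local_ip_auto())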
@@ -2197,6 +2299,90 @@ class Withable(Generic[T]):
         self._value = None


+def require_mlp_tp_gather(server_args):
+    """
+    Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
+    """
+    if server_args.enable_dp_attention:
+        assert server_args.dp_size > 1, "dp_size must be greater than 1"
+        if (
+            server_args.moe_dense_tp_size is None
+        ):  # TODO(ch-wan): some MoE models do not have dense layers
+            return True
+        elif not server_args.enable_dp_lm_head:
+            return True
+        elif not server_args.enable_deepep_moe:
+            return True
+        else:
+            return (
+                server_args.moe_dense_tp_size
+                > server_args.tp_size // server_args.dp_size
+            )
+    else:
+        return False
+
+
+def require_attn_tp_gather(server_args):
+    """
+    Check if the input of attention is scattered.
+    """
+    assert server_args.moe_dense_tp_size in [1, None]
+    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
+        if server_args.enable_dp_attention:
+            return server_args.dp_size < server_args.tp_size
+        else:
+            return True
+    else:
+        return False
+
+
+def require_gathered_buffer(server_args):
+    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
+
+
+def require_mlp_sync(server_args):
+    return server_args.enable_dp_attention or require_gathered_buffer(server_args)
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
+
+
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf

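Not part of the diff, a minimal sketch of merge_bias_tensor's padding behavior when only one of the two batches being merged carries a bias; the missing side is filled with the default value before concatenation:

import torch

from sglang.srt.utils import merge_bias_tensor

lhs = torch.tensor([[0.5], [0.5]])  # per-request bias for a batch of 2
rhs = None                          # the batch being merged in has no bias

merged = merge_bias_tensor(lhs, rhs, bs1=2, bs2=3, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([5, 1]); rows 2..4 are filled with 0.0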
@@ -2282,3 +2468,41 @@ class LazyValue:
         self._value = self._creator()
         self._creator = None
         return self._value
+
+
+def dynamic_import(func_path: str):
+    parts = func_path.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            "func_path should contain both module name and func name (such as 'module.func')"
+        )
+    module_path = ".".join(parts[:-1])
+    func_name = parts[-1]
+    module = importlib.import_module(module_path)
+    func = getattr(module, func_name)
+    return func
+
+
+def configure_gc_logger():
+    logger.info("Enable GC Logger")
+
+    import gc
+
+    gc_start_time = {}
+
+    def gc_callback(phase, info):
+        gen = info.get("generation", "?")
+        if phase == "start":
+            gc_start_time[gen] = time.time()
+            logger.info(f"GC start: Time {time.time()} | Generation {gen}")
+        elif phase == "stop":
+            duration = time.time() - gc_start_time.get(gen, time.time())
+            collected = info.get("collected", "?")
+            uncollectable = info.get("uncollectable", "?")
+            logger.info(
+                f"GC end: Time {time.time()} | Generation {gen} | "
+                f"Duration: {duration:.4f}s | Collected: {collected} | Uncollectable: {uncollectable} "
+                f'{"(LONG GC)" if duration > 0.1 else ""}'
+            )
+
+    gc.callbacks.append(gc_callback)
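Not part of the diff, a short sketch of the two new helpers: dynamic_import resolves a dotted "module.func" path at runtime, and configure_gc_logger registers a gc callback that logs collection start/stop times. The json.dumps target is illustrative:

import gc

from sglang.srt.utils import configure_gc_logger, dynamic_import

json_dumps = dynamic_import("json.dumps")
print(json_dumps({"hello": "world"}))

configure_gc_logger()  # subsequent collections are logged, marked "(LONG GC)" past 0.1s
gc.collect()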
sglang/test/attention/test_prefix_chunk_info.py CHANGED
@@ -2,6 +2,8 @@ import unittest

 import torch

+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.test.test_utils import CustomTestCase
sglang/test/runners.py CHANGED
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]

 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
             self.model = _get_sentence_transformer_embedding_model(
                 model_path, torch_dtype
             )
-        elif self.model_type == "reward":
+        elif self.model_type == "reward" or self.model_type == "cross_encoder":
             from transformers import AutoModelForSequenceClassification

             self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                     else:
                         logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
+                elif self.model_type == "cross_encoder":
+                    inputs = self.tokenizer(
+                        prompts, padding=True, return_tensors="pt"
+                    ).to("cuda")
+                    scores = self.model(**inputs).logits
+                    scores = scores.squeeze().tolist()
+                    if not isinstance(scores, list):
+                        scores = [scores]
+                    out_queue.put(ModelOutput(scores=scores))

                 elif self.model_type == "reward":
                     scores = []
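For context on what the cross_encoder branch above computes: a sequence-classification reranker scores each (query, document) pair with a single logit. Not part of the diff, a minimal standalone sketch; the model name is illustrative:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # illustrative reranker
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

query = "How many people live in Berlin?"
documents = [
    "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
    "Berlin is well known for its museums.",
]

with torch.no_grad():
    # Encode each (query, document) pair; the model emits one relevance logit per pair.
    inputs = tokenizer(
        [query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt"
    )
    scores = model(**inputs).logits.squeeze(-1).tolist()

print(scores)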
@@ -322,7 +346,9 @@ class HFRunner:

     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:

     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
             else:
                 logits = [response["embedding"]]
             return ModelOutput(embed_logits=logits)
+        # cross encoder model
+        elif self.model_type == "cross_encoder":
+            response = self.engine.rerank(prompts)
+            if not isinstance(response, list):
+                response = [response]
+            scores = [x["embedding"] for x in response]
+            return ModelOutput(scores=scores)
         # reward model
         else:
             response = self.engine.encode(prompts)
sglang/test/test_block_fp8.py CHANGED
@@ -343,6 +343,7 @@ class TestW8A8BlockFP8Matmul(CustomTestCase):
     OUT_DTYPES = [torch.bfloat16]
     M = [64, 128, 512, 1024, 4096]
     NKs = [
+        (2112, 7168),
         (1536, 7168),
         (3072, 1536),
         (24576, 7168),