sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/conn.py

@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
     ):
-        self.engine = MooncakeTransferEngine()
         self.kv_args = args
+        self.engine = MooncakeTransferEngine(
+            hostname=get_local_ip_by_remote(),
+            gpu_id=self.kv_args.gpu_id,
+            ib_device=self.kv_args.ib_device,
+        )
         self.disaggregation_mode = disaggregation_mode
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -387,6 +391,10 @@ class MooncakeKVSender(BaseKVSender):


 class MooncakeKVReceiver(BaseKVReceiver):
+    _ctx = zmq.Context()
+    _socket_cache = {}
+    _socket_locks = {}
+    _global_lock = threading.Lock()

     def __init__(
         self,
@@ -436,11 +444,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    @cache
-    def _connect(self, endpoint: str):
-        socket = zmq.Context().socket(zmq.PUSH)
-        socket.connect(endpoint)
-        return socket
+    @classmethod
+    def _connect(cls, endpoint: str):
+        with cls._global_lock:
+            if endpoint not in cls._socket_cache:
+                sock = cls._ctx.socket(zmq.PUSH)
+                sock.connect(endpoint)
+                cls._socket_cache[endpoint] = sock
+                cls._socket_locks[endpoint] = threading.Lock()
+        return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
         self.prefill_server_url = (
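Note on the `_connect` rewrite above: the old `@cache` on an instance method keyed the cache on `(self, endpoint)` and spun up a fresh `zmq.Context` on every miss, so each receiver held its own socket. The new version keeps one class-level context and caches one PUSH socket per endpoint; since ZMQ sockets are not thread-safe, each socket is paired with a lock that callers hold while sending. A minimal standalone sketch of the same pattern (the `SocketPool` name and the endpoint are illustrative, not sglang API):

    import threading

    import zmq

    class SocketPool:
        _ctx = zmq.Context()             # one shared context per process
        _cache = {}                      # endpoint -> connected PUSH socket
        _locks = {}                      # endpoint -> per-socket send lock
        _global_lock = threading.Lock()  # guards cache mutation

        @classmethod
        def connect(cls, endpoint: str):
            with cls._global_lock:
                if endpoint not in cls._cache:
                    sock = cls._ctx.socket(zmq.PUSH)
                    sock.connect(endpoint)
                    cls._cache[endpoint] = sock
                    cls._locks[endpoint] = threading.Lock()
            return cls._cache[endpoint], cls._locks[endpoint]

    # Usage: hold the per-socket lock for the whole send, as init() does below.
    sock, lock = SocketPool.connect("tcp://127.0.0.1:5555")
    with lock:
        sock.send_multipart([b"room-7", b"payload"])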
@@ -456,18 +468,20 @@ class MooncakeKVReceiver(BaseKVReceiver):
         packed_aux_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         )
-        self._connect("tcp://" + self.prefill_server_url).send_multipart(
-            [
-                str(self.bootstrap_room).encode("ascii"),
-                get_local_ip_by_remote().encode("ascii"),
-                str(self.kv_mgr.rank_port).encode("ascii"),
-                self.session_id.encode("ascii"),
-                packed_kv_data_ptrs,
-                kv_indices.tobytes(),
-                packed_aux_data_ptrs,
-                str(aux_index).encode("ascii"),
-            ]
-        )
+        sock, lock = self._connect("tcp://" + self.prefill_server_url)
+        with lock:
+            sock.send_multipart(
+                [
+                    str(self.bootstrap_room).encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.session_id.encode("ascii"),
+                    packed_kv_data_ptrs,
+                    kv_indices.tobytes(),
+                    packed_aux_data_ptrs,
+                    str(aux_index).encode("ascii"),
+                ]
+            )

     def poll(self) -> KVPoll:
         return self.kv_mgr.check_status(self.bootstrap_room)
@@ -493,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self.thread.start()

     def _setup_routes(self):
-        self.app.router.add_route("*", "/metadata", self._handle_metadata)
         self.app.router.add_route("*", "/route", self._handle_route)

-    async def _handle_metadata(self, request: web.Request):
-        key = request.query.get("key", "")
-
-        if request.method == "GET":
-            return await self._handle_metadata_get(key)
-        elif request.method == "PUT":
-            return await self._handle_metadata_put(key, request)
-        elif request.method == "DELETE":
-            return await self._handle_metadata_delete(key)
-        return web.Response(
-            text="Method not allowed", status=405, content_type="application/json"
-        )
-
-    async def _handle_metadata_get(self, key):
-        async with self.lock:
-            value = self.store.get(key)
-        if value is None:
-            return web.Response(
-                text="metadata not found", status=404, content_type="application/json"
-            )
-        return web.Response(body=value, status=200, content_type="application/json")
-
-    async def _handle_metadata_put(self, key, request):
-        data = await request.read()
-        async with self.lock:
-            self.store[key] = data
-        return web.Response(
-            text="metadata updated", status=200, content_type="application/json"
-        )
-
-    async def _handle_metadata_delete(self, key):
-        async with self.lock:
-            if key not in self.store:
-                return web.Response(
-                    text="metadata not found",
-                    status=404,
-                    content_type="application/json",
-                )
-            del self.store[key]
-        return web.Response(
-            text="metadata deleted", status=200, content_type="application/json"
-        )
-
     async def _handle_route(self, request: web.Request):
         method = request.method
         if method == "PUT":
sglang/srt/disaggregation/mooncake/transfer_engine.py

@@ -1,45 +1,14 @@
 import json
 import logging
-import os
-import uuid
 from dataclasses import dataclass
+from typing import Optional

 logger = logging.getLogger(__name__)


-@dataclass
-class MooncakeTransferEngineConfig:
-    local_hostname: str
-    metadata_server: str
-    protocol: str
-    device_name: str
-
-    @staticmethod
-    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
-        """Load the config from a JSON file."""
-        with open(file_path) as fin:
-            config = json.load(fin)
-        return MooncakeTransferEngineConfig(
-            local_hostname=config.get("local_hostname", None),
-            metadata_server=config.get("metadata_server"),
-            protocol=config.get("protocol", "rdma"),
-            device_name=config.get("device_name", ""),
-        )
-
-    @staticmethod
-    def load_from_env() -> "MooncakeTransferEngineConfig":
-        """Load config from a file specified in the environment variable."""
-        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
-        if config_file_path is None:
-            raise ValueError(
-                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
-            )
-        return MooncakeTransferEngineConfig.from_file(config_file_path)
-
-
 class MooncakeTransferEngine:

-    def __init__(self):
+    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
         try:
             from mooncake.engine import TransferEngine
         except ImportError as e:
@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
             ) from e

         self.engine = TransferEngine()
+        self.hostname = hostname
+        self.gpu_id = gpu_id
+        self.ib_device = ib_device

-        try:
-            self.config = MooncakeTransferEngineConfig.load_from_env()
-            logger.info("Mooncake Configuration loaded successfully.")
-        except ValueError as e:
-            logger.error(e)
-            raise
-        except Exception as exc:
-            logger.error("An error occurred while loading the configuration: %s", exc)
-            raise
-
-        self.config = MooncakeTransferEngineConfig.load_from_env()
-
-        session_suffix = "_" + str(uuid.uuid4())
-        self.session_id = self.config.local_hostname + session_suffix
         self.initialize(
-            self.session_id,
-            self.config.metadata_server,
-            self.config.protocol,
-            self.config.device_name,
+            hostname=self.hostname,
+            device_name=self.ib_device,
         )
+        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"

     def register(self, ptr, length):
-        self.engine.register_memory(ptr, length)
+        ret_value = self.engine.register_memory(ptr, length)
+        if ret_value != 0:
+            logger.error("Mooncake memory registration failed.")
+            raise RuntimeError("Mooncake memory registration failed.")

     def deregister(self, ptr):
-        self.engine.unregister_memory(ptr)
+        ret_value = self.engine.unregister_memory(ptr)
+        if ret_value != 0:
+            logger.error("Mooncake memory deregistration failed.")
+            raise RuntimeError("Mooncake memory deregistration failed.")

     def initialize(
         self,
-        local_hostname: str,
-        metadata_server: str,
-        protocol: str,
-        device_name: str,
+        hostname: str,
+        device_name: Optional[str],
     ) -> None:
         """Initialize the mooncake instance."""
-        self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
+        ret_value = self.engine.initialize(
+            hostname,
+            "P2PHANDSHAKE",
+            "rdma",
+            device_name if device_name is not None else "",
+        )
+        if ret_value != 0:
+            logger.error("Mooncake Transfer Engine initialization failed.")
+            raise RuntimeError("Mooncake Transfer Engine initialization failed.")

     def transfer_sync(
         self, session_id: str, buffer: int, peer_buffer_address: int, length: int
@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
             session_id, buffer, peer_buffer_address, length
         )
         if ret < 0:
-            logger.error("Transfer Return Error")
-            raise Exception("Transfer Return Error")
+            logger.error("Mooncake Transfer Engine Return Error.")
+            raise RuntimeError("Mooncake Transfer Engine Return Error.")
         return ret

     def get_localhost(self):
-        return self.config.local_hostname
+        return self.hostname

     def get_session_id(self):
         return self.session_id
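Taken together, the transfer_engine.py hunks replace the file-based configuration (`MOONCAKE_CONFIG_PATH`, `MooncakeTransferEngineConfig`) with constructor arguments, hard-code the `"rdma"` protocol and the engine's built-in `"P2PHANDSHAKE"` metadata mode in place of a separate metadata server, check every engine call for a non-zero return code, and derive the session id from the engine's RPC port instead of a random UUID. A hedged usage sketch (the hostname literal is made up; in sglang the values come from `get_local_ip_by_remote()` and `kv_args`, as the conn.py hunk above shows):

    from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine

    engine = MooncakeTransferEngine(
        hostname="10.0.0.5",  # local IP reachable by the remote peer
        gpu_id=0,             # stored on the engine; not passed to initialize()
        ib_device=None,       # optional IB device name; "" is sent when None
    )
    # The session id now encodes the engine's RPC port, e.g. "10.0.0.5:12345".
    print(engine.get_session_id())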
sglang/srt/disaggregation/prefill.py

@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    kv_to_page_indices,
+    kv_to_page_num,
     poll_and_all_reduce,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -103,7 +105,7 @@ class PrefillBootstrapQueue:
         kv_args.aux_item_lens = [
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
-        kv_args.ib_device = "mock-ib-device"
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
             self.req_to_metadata_buffer_idx_allocator.alloc()
         )
         assert req.metadata_buffer_index is not None
-        req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
+        num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
+        req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)

         bootstrapped_reqs.append(req)
         indices_to_remove.add(i)
@@ -171,6 +174,36 @@ class SchedulerDisaggregationPrefillMixin:
     Mixin for Scheduler to handle disaggregation prefill
     """

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
@@ -210,7 +243,7 @@ class SchedulerDisaggregationPrefillMixin:

         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.tp_worker.get_tp_cpu_group(),
+            self.attn_tp_cpu_group,
         )

         undone_reqs: List[Req] = []
@@ -270,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
             req.metadata_buffer_index, token_id
         )
         is_last = token_id is not None
-        req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
+        page_indices = kv_to_page_indices(
+            kv_indices, self.token_to_kv_pool_allocator.page_size
+        )
+        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
sglang/srt/disaggregation/utils.py

@@ -4,6 +4,7 @@ from collections import deque
 from enum import Enum
 from typing import List

+import numpy as np
 import torch
 import torch.distributed as dist

@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         }
         return class_mapping.get(class_type)
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
+
+
+def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
+    # 1. Every page is guaranteed to be full except the last one.
+    # 2. page index = kv_index // page_size
+    # The returned vector is kv_indices[::page_size] // page_size
+    if page_size == 1:  # shortcut
+        return kv_indices
+    return kv_indices[::page_size] // page_size
+
+
+def kv_to_page_num(num_kv_indices: int, page_size: int):
+    # ceil(num_kv_indices / page_size)
+    return (num_kv_indices + page_size - 1) // page_size
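A quick worked check of the two helpers (values are illustrative; the example assumes `kv_indices` is contiguous and page-aligned, which is what guarantee 1 in the comment implies):

    import numpy as np

    from sglang.srt.disaggregation.utils import kv_to_page_indices, kv_to_page_num

    kv_indices = np.arange(8, 19)  # 11 token slots: 8..18
    page_size = 4                  # slots 8..18 live on pages 2, 3, 4

    # One representative kv index per page, integer-divided by the page size.
    print(kv_to_page_indices(kv_indices, page_size))   # [2 3 4]
    # ceil(11 / 4) = 3 pages are needed for 11 slots.
    print(kv_to_page_num(len(kv_indices), page_size))  # 3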
sglang/srt/entrypoints/verl_engine.py

@@ -12,18 +12,17 @@
 # limitations under the License.
 # ==============================================================================
 import os
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.server import Engine
-from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj

@@ -125,7 +124,7 @@ class VerlEngine:

     def update_weights_from_tensor(
         self,
-        named_tensors: List[Tuple[str, torch.Tensor]],
+        named_tensors: Iterable[Tuple[str, torch.Tensor]],
         load_format: Optional[str] = None,
     ):
         # Most naive implementation, can optimize a lot if it is bottleneck
@@ -154,9 +153,12 @@ class VerlEngine:
                 )
             ],
             load_format=load_format,
-            flush_cache=tensor_index == len(named_tensors) - 1,
+            flush_cache=False,
         )

+        if self._tp_rank == 0:
+            self._engine.tokenizer_manager.flush_cache()
+
     def release_memory_occupation(self):
         if self._tp_rank == 0:
             self._engine.release_memory_occupation()
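Two changes in `update_weights_from_tensor` work together: the parameter type widens from `List` to `Iterable`, and the per-tensor `flush_cache=tensor_index == len(named_tensors) - 1` is dropped, since `len()` is undefined for a general iterable. Every per-tensor update now passes `flush_cache=False` and the cache is flushed exactly once at the end, on TP rank 0. A condensed sketch of the resulting control flow (the tensor gathering and serialization of the real method are omitted here):

    def update_weights_from_tensor(self, named_tensors, load_format=None):
        # Iterating works for lists and generators alike; no len() needed.
        for name, tensor in named_tensors:
            self._engine.update_weights_from_tensor(
                named_tensors=[(name, tensor)],
                load_format=load_format,
                flush_cache=False,  # defer flushing until all tensors land
            )
        # Flush once after the loop instead of on the "last" element.
        if self._tp_rank == 0:
            self._engine.tokenizer_manager.flush_cache()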
sglang/srt/layers/activation.py

@@ -21,13 +21,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_cuda_available
-
-_is_cuda = is_cuda_available()
-
-if _is_cuda:
-    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 logger = logging.getLogger(__name__)