sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -1,23 +1,22 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import pkgutil
@@ -56,10 +55,13 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
-    monkey_patch_vllm_dummy_weight_loader,
+    is_hip,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )
 
 logger = logging.getLogger(__name__)
@@ -113,7 +115,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.warning(
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -139,15 +141,26 @@
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
-                "disable_penalizer": server_args.disable_penalizer,
-                "disable_nan_detection": server_args.disable_nan_detection,
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
 
-        # Init componnets
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -166,14 +179,15 @@
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
         # Init torch distributed
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
-            torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
         # ToDO(liangan1):Just use gloo to bypass the initilization fail
         # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
-            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"
+        elif self.device == "hpu":
+            backend = "hccl"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -215,6 +229,49 @@
 
         return min_per_gpu_memory
 
+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            pass
+
+        return get_model(
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
+        )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+
+        if "task" in sig.parameters:
+            params["task"] = ""
+
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -232,42 +289,25 @@
             raise RuntimeError("SGLang only supports sm75 and above.")
 
         # Prepare the vllm model config
-        monkey_patch_vllm_dummy_weight_loader()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-        self.dtype = self.vllm_model_config.dtype
 
-        # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
            else None
         )
+        self.dtype = self.vllm_model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -293,17 +333,9 @@
         target_device = torch.device(self.device)
 
         try:
-            # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
@@ -412,7 +444,10 @@
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            self.kv_cache_dtype = torch.float8_e5m2
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -551,6 +586,13 @@
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
@@ -564,9 +606,17 @@
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
@@ -576,21 +626,37 @@
                 get_embedding=True,
             )
 
+    def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
         else:
            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.update_regex_vocab_mask()
-        sampling_info.update_penalties()
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
         # Sample the next tokens.
@@ -616,7 +682,7 @@
 
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
 
         return logits
 
@@ -640,7 +706,9 @@ def import_model_classes():
        try:
            module = importlib.import_module(name)
        except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
            continue
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
sglang/srt/model_parallel.py ADDED
@@ -0,0 +1,98 @@
+"""
+Common utilities for torch model parallelism.
+"""
+
+from typing import Optional
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.distributed.tensor import DTensor, Shard
+except ImportError:
+    # torch 2.4 or older
+    from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+
+class ColwiseParallelSharded(ColwiseParallel):
+    """
+    A version of ColwiseParallel where the local weight has been already
+    sharded. This is used for the fused wqkv case, where during loading, we
+    already sharded wq, wk, wv before fusing them.
+    """
+
+    # Override the _partition_linear_fn in ColwiseParallel
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(0)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        for name, param in module.named_parameters():
+            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
+            module.register_parameter(name, dist_param)
+
+
+class RowwiseParallelMaybeWait(RowwiseParallel):
+    """
+    A version of RowwiseParallel that waits for the output (establish dependency
+    between comm stream and compute stream in CUDA sense) before going into the
+    next op. This is needed to workaround the current interaction between
+    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
+    """
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        outputs = super(
+            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
+        )._prepare_output_fn(
+            output_layouts, use_local_output, mod, outputs, device_mesh
+        )
+        # wait for the output to be ready
+        if isinstance(outputs, AsyncCollectiveTensor):
+            return outputs.wait()
+        else:
+            return outputs
+
+
+def tensor_parallel(
+    module: torch.nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+):
+    """
+    Tensor parallelize the model across the given device mesh.
+    Args:
+        module (`torch.nn.Module`):
+            The module to tensor parallelize.
+        device_mesh (`torch.distributed.DeviceMesh`):
+            The device mesh to use for tensor parallelism.
+    """
+
+    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+    # No op if `_tp_plan` attribute does not exist under the module.
+    # This is a helper function to be used with `model.apply` to recursively
+    # parallelize a model.
+    def tplize(mod: torch.nn.Module) -> None:
+        tp_plan = getattr(mod, "_tp_plan", None)
+        if tp_plan is None:
+            return
+        for child_name, tp_style in tp_plan.items():
+            submod = mod.get_submodule(child_name)
+            if tp_style == "Colwise":
+                parallelize_module(submod, device_mesh, ColwiseParallel())
+            elif tp_style == "Rowwise":
+                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
+            elif tp_style == "Colwise_Sharded":
+                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
+            else:
+                raise ValueError(f"Unknown TP style {tp_style}")
+
+    # `apply` is a native method of `nn.Module` that recursively applies a
+    # function to every submodule.
+    module.apply(tplize)
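The new sglang/srt/model_parallel.py drives sharding entirely from a per-module `_tp_plan` attribute; ModelRunner.apply_torch_tp() (see the hunk above) builds a 1-D device mesh and calls tensor_parallel() whenever a model class sets supports_torch_tp. The sketch below shows how a model could opt in; it is not taken from the package, and the ToyMLP module, its dimensions, and the torchrun launch are illustrative assumptions only.

# Hypothetical usage sketch (not from the diff): a toy module opting into the
# torch-native TP path added in this release. Launch with
# `torchrun --nproc-per-node=<tp_size> example.py` on a multi-GPU CUDA machine.
import torch
import torch.distributed as dist

from sglang.srt.model_parallel import tensor_parallel


class ToyMLP(torch.nn.Module):
    # ModelRunner checks this flag before calling apply_torch_tp().
    supports_torch_tp = True

    def __init__(self, hidden: int = 128):
        super().__init__()
        self.up_proj = torch.nn.Linear(hidden, 4 * hidden, bias=False)
        self.down_proj = torch.nn.Linear(4 * hidden, hidden, bias=False)
        # tplize() reads this plan and shards each named child accordingly.
        self._tp_plan = {"up_proj": "Colwise", "down_proj": "Rowwise"}

    def forward(self, x):
        return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    mesh = torch.distributed.init_device_mesh("cuda", (dist.get_world_size(),))
    model = ToyMLP().cuda()
    tensor_parallel(model, mesh)  # mirrors ModelRunner.apply_torch_tp()
    out = model(torch.randn(4, 128, device="cuda"))

The Colwise/Rowwise pairing is the usual Megatron-style split: the column-parallel up projection shards the intermediate dimension, and the row-parallel down projection all-reduces the partial outputs, which is exactly what RowwiseParallelMaybeWait waits on before the next custom op runs.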
sglang/srt/models/chatglm.py CHANGED
@@ -1,22 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-# coding=utf-8
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/commandr.py CHANGED
@@ -1,19 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-# coding=utf-8
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -32,12 +29,14 @@ limitations under the License.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# ==============================================================================
 
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
sglang/srt/models/dbrx.py CHANGED
@@ -1,21 +1,20 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
-# coding=utf-8
+
 from typing import Iterable, Optional, Tuple
 
 import torch
@@ -25,11 +24,11 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
sglang/srt/models/deepseek.py CHANGED
@@ -1,21 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py
 """Inference-only Deepseek model."""
+
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -26,11 +26,11 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
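As the dbrx and deepseek hunks above show, this release stops importing fused_moe from vllm and instead uses the vendored sglang.srt.layers.fused_moe_triton package (files 30-32 in the list). Because only the import line changes at these call sites, the kernel presumably keeps vllm's fused_moe calling convention; the sketch below illustrates that migration, with the shapes and keyword treated as assumptions rather than a verified description of the new module.

# Illustrative sketch only: the new import path comes from this diff, but the
# argument list is assumed to match the vllm Triton kernel it replaces.
import torch

from sglang.srt.layers.fused_moe_triton import fused_moe  # was vllm.model_executor.layers.fused_moe

tokens, hidden, inter, experts, topk = 4, 128, 256, 8, 2
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")  # fused gate+up weights
w2 = torch.randn(experts, hidden, inter, dtype=torch.float16, device="cuda")      # down weights
router_logits = torch.randn(tokens, experts, dtype=torch.float16, device="cuda")

out = fused_moe(x, w1, w2, router_logits, topk, renormalize=True)  # expected shape: [tokens, hidden]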