sglang 0.4.0.post1__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/bench_offline_throughput.py +18 -6
  2. sglang/bench_one_batch.py +13 -0
  3. sglang/bench_serving.py +8 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/constrained/xgrammar_backend.py +4 -1
  9. sglang/srt/layers/attention/flashinfer_backend.py +2 -0
  10. sglang/srt/layers/attention/triton_backend.py +16 -25
  11. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  12. sglang/srt/layers/ep_moe/layer.py +4 -0
  13. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  14. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/quantization/__init__.py +2 -47
  17. sglang/srt/layers/quantization/fp8.py +58 -10
  18. sglang/srt/layers/radix_attention.py +8 -1
  19. sglang/srt/layers/sampler.py +27 -5
  20. sglang/srt/layers/torchao_utils.py +35 -0
  21. sglang/srt/managers/detokenizer_manager.py +37 -17
  22. sglang/srt/managers/io_struct.py +39 -10
  23. sglang/srt/managers/schedule_batch.py +38 -24
  24. sglang/srt/managers/schedule_policy.py +64 -5
  25. sglang/srt/managers/scheduler.py +169 -134
  26. sglang/srt/managers/tokenizer_manager.py +99 -58
  27. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  28. sglang/srt/mem_cache/chunk_cache.py +2 -2
  29. sglang/srt/mem_cache/radix_cache.py +12 -2
  30. sglang/srt/model_executor/cuda_graph_runner.py +24 -10
  31. sglang/srt/model_executor/model_runner.py +22 -14
  32. sglang/srt/model_parallel.py +66 -5
  33. sglang/srt/models/gemma2.py +34 -0
  34. sglang/srt/models/gemma2_reward.py +0 -1
  35. sglang/srt/models/granite.py +517 -0
  36. sglang/srt/models/grok.py +72 -8
  37. sglang/srt/models/llama.py +22 -0
  38. sglang/srt/models/llama_classification.py +11 -23
  39. sglang/srt/models/llama_reward.py +0 -2
  40. sglang/srt/models/llava.py +37 -14
  41. sglang/srt/models/qwen2.py +20 -0
  42. sglang/srt/openai_api/adapter.py +4 -0
  43. sglang/srt/openai_api/protocol.py +9 -4
  44. sglang/srt/server.py +1 -1
  45. sglang/srt/server_args.py +19 -9
  46. sglang/srt/utils.py +7 -10
  47. sglang/test/test_utils.py +3 -2
  48. sglang/utils.py +10 -3
  49. sglang/version.py +1 -1
  50. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +11 -6
  51. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +54 -52
  52. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  53. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  54. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 
 import fastapi
 import uvloop
@@ -76,6 +76,7 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any
 
     # For metrics
     created_time: float
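
The hunk above, together with the `_wait_one_response` change below, means each per-request state now carries the original request object, so the batch-output handler can read flags such as `return_logprob` directly from it. A minimal sketch of the resulting dataclass, assuming only the fields visible in this diff plus `first_token_time`, which the metrics code further down sets lazily:

import asyncio
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ReqState:
    """Per-request bookkeeping kept by TokenizerManager (abridged sketch)."""

    out_list: List          # outputs accumulated since the last yield
    finished: bool          # set once a finish reason is reported
    event: asyncio.Event    # signaled whenever new output arrives
    obj: Any                # the original GenerateReqInput / EmbeddingReqInput

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None  # assumed default; filled on first token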
@@ -283,7 +284,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
 
         while True:
@@ -295,15 +296,7 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue
 
-            if isinstance(obj, GenerateReqInput):
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]
 
             state.out_list = []
             if state.finished:
@@ -315,7 +308,13 @@ class TokenizerManager:
                 break
 
             state.event.clear()
-            yield out
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
 
     async def _handle_batch_request(
         self,
@@ -573,7 +572,7 @@ class TokenizerManager:
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
-            await asyncio.sleep(60)
+            await asyncio.sleep(5)
 
         # drain requests
         while True:
@@ -609,29 +608,55 @@
             if state is None:
                 continue
 
-            recv_obj.meta_info[i]["id"] = rid
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
+                )
+
             if isinstance(recv_obj, BatchStrOut):
                 out_dict = {
                     "text": recv_obj.output_strs[i],
-                    "meta_info": recv_obj.meta_info[i],
+                    "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 out_dict = {
                     "token_ids": recv_obj.output_ids[i],
-                    "meta_info": recv_obj.meta_info[i],
+                    "meta_info": meta_info,
                 }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
-                    "meta_info": recv_obj.meta_info[i],
+                    "meta_info": meta_info,
                 }
             state.out_list.append(out_dict)
-            state.finished = recv_obj.finished_reason[i] is not None
+            state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
 
             if self.enable_metrics:
-                completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+                completion_tokens = (
+                    recv_obj.completion_tokens[i]
+                    if recv_obj.completion_tokens
+                    else 0
+                )
 
                 if state.first_token_time is None:
                     state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:
 
                 if state.finished:
                     self.metrics_collector.inc_prompt_tokens(
-                        recv_obj.meta_info[i]["prompt_tokens"]
+                        recv_obj.prompt_tokens[i]
                     )
                     self.metrics_collector.inc_generation_tokens(
                         completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:
 
     def convert_logprob_style(
         self,
-        ret: dict,
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-        if return_logprob:
-            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
-            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
 
-            if top_logprobs_num > 0:
-                ret["meta_info"]["input_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["input_top_logprobs"],
-                        return_text_in_logprobs,
-                    )
-                )
-                ret["meta_info"]["output_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                    )
-                )
-        return ret
-
     def detokenize_logprob_tokens(
-        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
     ):
-        # TODO(lianmin): This should run on DetokenizerManager
         if not decode_to_text:
-            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-        assert self.tokenizer is not None
-        token_ids = [tid for _, tid in token_logprobs]
-        token_texts = self.tokenizer.batch_decode(token_ids)
-        return [
-            (logprob, token_id, token_text)
-            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
-        ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
 
-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
         # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
         # We should batch all top-k tokens in all positions.
-        for i, token_top_logprobs in enumerate(top_logprobs):
-            if token_top_logprobs:
-                top_logprobs[i] = self.detokenize_logprob_tokens(
-                    token_top_logprobs, decode_to_text
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                    )
                 )
-        return top_logprobs
+            else:
+                ret.append(None)
+        return ret
 
 
 class SignalHandler:
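
Net effect of the hunk above: per-token logprobs now arrive from the scheduler/detokenizer as parallel `*_val` / `*_idx` lists attached to the batch output object, and the tokenizer manager assembles the `(logprob, token_id, text_or_None)` triples itself. A standalone sketch of that layout with made-up values (not sglang code, just an illustration of what `detokenize_logprob_tokens` produces when `decode_to_text` is False):

# Parallel lists for one request, values illustrative only.
output_token_logprobs_val = [-0.11, -2.53, -0.87]
output_token_logprobs_idx = [1734, 29892, 3186]

# Zipping them yields one (logprob, token_id, None) triple per position.
triples = [
    (logprob, token_id, None)
    for logprob, token_id in zip(output_token_logprobs_val, output_token_logprobs_idx)
]
assert triples[0] == (-0.11, 1734, None)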

sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass
 
     @abstractmethod

sglang/srt/mem_cache/chunk_cache.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None
 

sglang/srt/mem_cache/radix_cache.py

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the Radix tree.
+            The last node create a new child if the prefix is shorter
+            than the last node's value.
+        """
         if self.disable:
             return [], self.root_node
 
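
The new docstring spells out the two-element return contract. A hedged, illustrative caller (the helper and variable names below are not from this diff; `len()` on the returned tensor gives the matched prefix length):

from typing import List


def split_cached_and_uncached(tree_cache, token_ids: List[int]):
    """Split a prompt into its cached prefix and the suffix that still needs prefill."""
    # value: tensor covering the matched prefix; last_node: the tree node the
    # caller would lock / extend for this request.
    value, last_node = tree_cache.match_prefix(key=token_ids)
    num_matched = len(value)
    return token_ids[:num_matched], token_ids[num_matched:], last_node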

sglang/srt/model_executor/cuda_graph_runner.py

@@ -20,6 +20,8 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -127,7 +129,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 32)) + [64, 128]
+            self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
@@ -255,7 +257,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in self.capture_bs:
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -387,8 +394,14 @@
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            next_token_logprobs = torch.nn.functional.log_softmax(
-                next_token_logits, dim=-1
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
             )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -396,13 +409,14 @@
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-                logits_metadata = LogitsMetadata(
-                    forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=forward_batch.top_logprobs_nums,
-                )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                (
+                    logits_output.output_top_logprobs_val,
+                    logits_output.output_top_logprobs_idx,
+                ) = LogitsProcessor.get_top_logprobs(
                     next_token_logprobs, logits_metadata
-                )[1]
+                )[
+                    2:4
+                ]
             else:
                 logits_output = LogitsProcessorOutput(
                     next_token_logits=next_token_logits,

sglang/srt/model_executor/model_runner.py

@@ -111,17 +111,20 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
-                f"and turn off chunked prefill "
-                f"because this is a multimodal model."
-            )
+            if self.model_config.hf_config.architectures == [
+                "MllamaForConditionalGeneration"
+            ]:
+                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                )
+                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True
 
         # Global vars
@@ -154,6 +157,11 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Apply torchao quantization
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -162,10 +170,6 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
-
         # Init memory pool and attention backends
        if server_args.lora_paths is not None:
             self.init_lora_manager()
@@ -242,20 +246,22 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
-        # Prepare the vllm model config
+        # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
+
+        # Load the model
         self.model = get_model(
             model_config=self.model_config,
             load_config=self.load_config,
             device_config=DeviceConfig(self.device),
         )
 
+        # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +276,10 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-    def update_weights_from_disk(self, model_path: str, load_format: str):
-        """Update engine weights online from disk."""
+    def update_weights_from_disk(
+        self, model_path: str, load_format: str
+    ) -> tuple[bool, str]:
+        """Update engine weights in-place from the disk."""
         from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,

sglang/srt/model_parallel.py

@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
 
-from typing import Optional
+from typing import Optional, Sequence
 
 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 
 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )
 
 
+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
 
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """
 
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
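
For context, these two parallel styles plug into the standard torch tensor-parallel entry point. A minimal sketch of how such a plan is typically applied, assuming it sits alongside the classes above; the plan keys ("gate_up_proj", "down_proj") and the helper name are illustrative, not taken from this diff:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


def tensor_parallelize_mlp(mlp: nn.Module, tp_size: int) -> nn.Module:
    """Illustrative helper: shard an MLP block across a 1-D tensor-parallel mesh."""
    # Requires an already-initialized torch.distributed process group.
    mesh = init_device_mesh("cuda", (tp_size,))
    plan = {
        # Up projection whose checkpoint weights are already sharded on load.
        "gate_up_proj": ColwiseParallelSharded(),
        # Down projection whose output may arrive as an AsyncCollectiveTensor.
        "down_proj": RowwiseParallelMaybeWait(),
    }
    return parallelize_module(mlp, mesh, plan)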

sglang/srt/models/gemma2.py

@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
 
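
The two hooks added above expose, per module, the checkpoint-to-fused-module name mapping and the fused module's input/output widths; this is the kind of metadata adapter loaders (for example a LoRA manager) consult, and llama.py and qwen2.py also gain small additions in this release, likely the same hooks. A standalone sketch that just exercises the mapping with a stand-in config (the sizes are placeholders, not read from any checkpoint):

from types import SimpleNamespace

# Stand-in config, loosely shaped like a small Gemma 2 model.
config = SimpleNamespace(
    hidden_size=2304,
    head_dim=256,
    num_attention_heads=8,
    num_key_value_heads=4,
    intermediate_size=9216,
)

# Mirrors get_module_name: checkpoint name -> fused module name.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

fused = params_mapping.get("k_proj", "k_proj")            # -> "qkv_proj"
# Mirrors get_hidden_dim("qkv_proj"): (input_dim, output_dim).
in_dim = config.hidden_size                               # 2304
out_dim = config.head_dim * config.num_attention_heads    # 2048
print(fused, (in_dim, out_dim))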
360
394
 

sglang/srt/models/gemma2_reward.py

@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = Gemma2Model(config, quant_config=quant_config)