tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
  49. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,136 @@
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from compressed_tensors.quantization import QuantizationStrategy
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+     CompressedTensorsW8A8Int8
+ from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+     convert_to_channelwise
+
+ from tpu_inference.layers.vllm.linear_common import (
+     sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+     torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
+
+     def __init__(self, strategy: str, is_static_input_scheme: bool,
+                  input_symmetric: bool, jax_config: JaxCommonLinearConfig):
+         super().__init__(strategy, is_static_input_scheme, input_symmetric)
+
+         self.jax_config = jax_config
+         self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         weight = torch_to_jax_param(
+             layer.weight,
+             NamedSharding(self.jax_config.mesh,
+                           self.jax_config.weight_sharding),
+             self.jax_config.output_sizes,
+             self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls,
+         )
+         delattr(layer, "weight")
+         layer.weight = weight
+
+         weight_scale = layer.weight_scale
+         is_fused_module = len(layer.logical_widths) > 1
+         if is_fused_module and not self.is_channelwise:
+             weight_scale = convert_to_channelwise(weight_scale,
+                                                   layer.logical_widths)
+         weight_scale = weight_scale.squeeze(-1)
+
+         weight_scale = torch_to_jax_param(
+             weight_scale,
+             NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
+             self.jax_config.output_sizes,
+             self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls,
+         )
+         delattr(layer, "weight_scale")
+         layer.weight_scale = weight_scale
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+         # TODO(kyuyeunk): Support static range input quantization.
+         assert getattr(layer, "input_scale", None) is None
+         assert getattr(layer, "input_zero_point", None) is None
+         assert getattr(layer, "azp_adj", None) is None
+
+     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                       bias: Optional[torch.Tensor]) -> torch.Tensor:
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+             return out
+
+     def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         x_jax = jax_view(x)
+         weight_jax = jax_view(layer.weight)
+         weight_scale_jax = jax_view(layer.weight_scale)
+
+         outs = sharded_quantized_matmul(
+             x_jax,
+             weight_jax,
+             weight_scale_jax,
+             self.jax_config.mesh,
+             self.jax_config.weight_sharding,
+         )
+         if bias is not None and not layer.skip_bias_add:
+             outs += jax_view(bias)
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         assert isinstance(layer.weight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         outs = []
+         for i, (weight, weight_scale) in enumerate(
+                 zip(layer.weight, layer.weight_scale)):
+             weight_jax = jax_view(weight)
+             weight_scale_jax = jax_view(weight_scale)
+
+             out = sharded_quantized_matmul(
+                 x_jax,
+                 weight_jax,
+                 weight_scale_jax,
+                 self.jax_config.mesh,
+                 self.jax_config.weight_sharding,
+             )
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
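
For context on the scheme above: a channelwise W8A8 int8 layer stores an int8 weight with one float scale per output channel, quantizes the activations on the fly, accumulates the int8 product in a wider dtype, and rescales the result back to float. The snippet below is a minimal single-device sketch of that arithmetic only; quantize_activations and w8a8_matmul are illustrative names, not the package's sharded_quantized_matmul, which additionally handles sharding across the mesh, so details may differ.

import jax.numpy as jnp

def quantize_activations(x):
    # Per-tensor symmetric int8 quantization of the activations.
    scale = jnp.max(jnp.abs(x)) / 127.0
    x_q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return x_q, scale

def w8a8_matmul(x, w_q, w_scale):
    # x: [m, n] float, w_q: [p, n] int8, w_scale: [p] float, one scale per output channel.
    x_q, x_scale = quantize_activations(x)
    # Accumulate the int8 x int8 product in int32, then rescale back to float.
    acc = jnp.einsum("mn,pn->mp", x_q.astype(jnp.int32), w_q.astype(jnp.int32))
    return acc.astype(jnp.float32) * x_scale * w_scale[None, :]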
@@ -0,0 +1,263 @@
+ import functools
+ from typing import Any, Callable, Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.experimental.layout import Format, Layout
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn.parameter import Parameter
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.attention.layer import Attention
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import (
+     FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
+ from vllm.model_executor.layers.fused_moe.modular_kernel import (
+     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
+ from vllm.model_executor.layers.linear import (LinearBase,
+                                                UnquantizedLinearMethod)
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import (
+     QuantizationConfig, QuantizeMethodBase)
+
+ from tpu_inference.layers.vllm.fused_moe import jax_fused_moe_func_padded
+ from tpu_inference.layers.vllm.linear_common import (
+     reorder_concatenated_tensor_for_sharding,
+     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config("jax-unquantized")
+ class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "jax-unquantized"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+         return [torch.float32, torch.float16, torch.bfloat16]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 0  # Always supported
+
+     @classmethod
+     def get_config_filenames(cls) -> list[str]:
+         return []  # No extra configs required.
+
+     @classmethod
+     def from_config(cls, _: dict[str, Any]) -> "VllmUnquantizedConfig":
+         return cls()
+
+     def get_quant_method(self, layer: torch.nn.Module,
+                          prefix: str) -> Optional[QuantizeMethodBase]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             return VllmUnquantizedLinearMethod(linear_config)
+         if isinstance(layer, FusedMoE):
+             moe_config = self.get_moe_config(layer)
+             return VllmUnquantizedFusedMoEMethod(moe_config, self.mesh)
+         if isinstance(layer, Attention):
+             return None
+         return None
+
+
+ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
+
+     def __init__(self, jax_config: JaxCommonLinearConfig):
+         self.jax_config = jax_config
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         weight = torch_to_jax_param(
+             layer.weight,
+             NamedSharding(self.jax_config.mesh,
+                           self.jax_config.weight_sharding),
+             self.jax_config.output_sizes,
+             self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls,
+         )
+         delattr(layer, "weight")
+         layer.weight = weight
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         with jax.named_scope(layer._get_name()):
+             if in_sharding := self.jax_config.get_input_sharding(x):
+                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
+
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+             if out_sharding := self.jax_config.get_output_sharding(out):
+                 out.shard_(NamedSharding(self.jax_config.mesh, out_sharding))
+
+             return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x_jax = jax_view(x)
+         weight_jax = jax_view(layer.weight)
+
+         outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
+         if bias is not None and not layer.skip_bias_add:
+             outs += bias.jax()
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         assert isinstance(layer.weight, torch.nn.ParameterList)
+
+         x_jax = x.jax()
+         outs = []
+         for i, weight in enumerate(layer.weight):
+             weight_jax = jax_view(weight)
+
+             out = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+
+ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+
+     def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+         super().__init__(moe)
+         self.mesh = mesh
+
+     def select_gemm_impl(
+         self,
+         prepare_finalize: FusedMoEPrepareAndFinalize,
+         moe: FusedMoEConfig,
+         layer: torch.nn.Module,
+     ) -> FusedMoEPermuteExpertsUnpermute:
+         raise NotImplementedError(
+             "Selecting gemm implementation is currently not supported.")
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         assert isinstance(layer, FusedMoE)
+
+         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+
+         if layer.use_ep:
+             w13_weight = jax.device_put(
+                 w13_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P("model", None, None))))
+             w2_weight = jax.device_put(
+                 w2_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P("model", None, None))))
+         else:
+             intermediate_size = w13_weight.shape[1] // 2
+             assert intermediate_size == w2_weight.shape[-1]
+             output_sizes = [intermediate_size, intermediate_size]
+             n_shards = self.mesh.shape["model"]
+             assert intermediate_size % n_shards == 0
+             w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                   output_sizes,
+                                                                   n_shards,
+                                                                   dim=1)
+             w13_weight = jax.device_put(
+                 w13_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P(None, "model", None))))
+             w2_weight = jax.device_put(
+                 w2_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P(None, None, "model"))))
+         w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+         w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+
+         layer.w13_weight = w13_weight
+         layer.w2_weight = w2_weight
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool = False,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         global_num_experts: int = -1,
+         expert_map: Optional[torch.Tensor] = None,
+         custom_routing_function: Optional[Callable] = None,
+         scoring_func: str = "softmax",
+         routed_scaling_factor: float = 1.0,
+         e_score_correction_bias: Optional[torch.Tensor] = None,
+         apply_router_weight_on_input: bool = False,
+         activation: str = "silu",
+         enable_eplb: bool = False,
+         expert_load_view: Optional[torch.Tensor] = None,
+         logical_to_physical_map: Optional[torch.Tensor] = None,
+         logical_replica_count: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         assert isinstance(layer, FusedMoE)
+         if activation != "silu":
+             raise NotImplementedError(
+                 "Only silu is supported for activation function.")
+         if scoring_func != "softmax":
+             raise NotImplementedError(
+                 "Only softmax is supported for scoring_func")
+
+         _fused_moe_func = functools.partial(
+             jax.jit(jax_fused_moe_func_padded,
+                     static_argnames=[
+                         "topk", "global_num_experts", "renormalize",
+                         "reduce_results", "mesh", "use_ep"
+                     ]),
+             topk=top_k,
+             global_num_experts=global_num_experts,
+             renormalize=renormalize,
+             reduce_results=layer.reduce_results,
+             mesh=self.mesh,
+             use_ep=layer.use_ep)
+
+         output = _fused_moe_func(
+             jax_view(x),
+             jax_view(layer.w13_weight),
+             jax_view(layer.w2_weight),
+             jax_view(router_logits),
+         )
+
+         return torch_view(output)
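
A note on the reorder and slice helpers used in this file: when a weight that concatenates several logical outputs (for example a fused gate/up projection) is sharded along the output axis, every shard must hold an equal slice of each logical output, so the concatenated axis is regrouped before sharding and the matmul output is split back apart afterwards. Below is a minimal 1-D sketch of that regrouping, assuming evenly divisible sizes; reorder_for_sharding is an illustrative stand-in, not the package's reorder_concatenated_tensor_for_sharding.

import itertools
import jax.numpy as jnp

def reorder_for_sharding(w, output_sizes, n_shards):
    # w is concatenated along axis 0 as [out_a; out_b; ...].
    # Regroup it so that shard s holds an equal slice of every logical output, contiguously.
    split_points = list(itertools.accumulate(output_sizes))[:-1]
    parts = jnp.split(w, split_points, axis=0)
    shard_chunks = [jnp.split(p, n_shards, axis=0) for p in parts]
    return jnp.concatenate(
        [c[s] for s in range(n_shards) for c in shard_chunks], axis=0)

# Two logical outputs of size 4, sharded over 2 devices:
w = jnp.arange(8)                          # [a0 a1 a2 a3 b0 b1 b2 b3]
print(reorder_for_sharding(w, [4, 4], 2))  # [0 1 4 5 2 3 6 7]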
@@ -0,0 +1,151 @@
+ import jax
+ import torch
+ import torchax
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn import Parameter
+ from torch.utils import _pytree as pytree
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
+                               MergedQKVParallelLinearWithLoRA,
+                               RowParallelLinearWithLoRA)
+ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead, VocabParallelEmbedding)
+
+ from tpu_inference.logger import init_logger
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ def shard_model_to_tpu(model: torch.nn.Module,
+                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
+     """
+     Shard the model weights and move them to TPU.
+
+     At the same time, turn the weight tensors into torchax tensors so that
+     JAX code can interop with them and the overall program can be traced and
+     compiled by XLA.
+
+     Args:
+         model: A PyTorch model whose weights are on CPU main memory.
+         mesh: JAX mesh object for sharding.
+     Returns:
+         Dictionary of parameters and buffers that will be used as arguments of
+         torch.func.functional_call.
+     """
+
+     with jax.default_device(jax.devices("cpu")[0]):
+         _shard_module_to_tpu(model, mesh)
+
+         params, buffers = _extract_all_params_buffers(model)
+
+         # For the remaining weight tensors, replicate them on all the TPU chips.
+         params, buffers = pytree.tree_map_only(
+             _tensor_is_in_cpu,
+             lambda tensor: _shard_tensor_to_tpu_replicated(tensor, mesh),
+             (params, buffers))
+
+     return {**params, **buffers}
+
+
+ def _extract_all_params_buffers(model: torch.nn.Module):
+     return dict(model.named_parameters()), dict(model.named_buffers())
+
+
+ def _tensor_is_in_cpu(tensor: torch.Tensor) -> bool:
+     # Check whether the tensor has not been converted to a torchax tensor yet.
+     if not isinstance(tensor, torchax.tensor.Tensor):
+         return True
+     # Check whether the torchax tensor is still on CPU.
+     return tensor.jax_device == jax.devices('cpu')[0]
+
+
+ def _convert_to_torchax_and_shard(tensor: torch.Tensor,
+                                   sharding: NamedSharding) -> torch.Tensor:
+     if isinstance(tensor, torchax.tensor.Tensor):
+         tensor = jax_view(tensor)
+     else:
+         tensor = t2j(tensor)
+     return torch_view(jax.device_put(tensor, sharding))
+
+
+ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
+                                     mesh: Mesh) -> torchax.tensor.Tensor:
+     return _convert_to_torchax_and_shard(tensor, NamedSharding(mesh, P()))
+
+
+ def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
+                                     mesh: Mesh) -> None:
+     weight = _convert_to_torchax_and_shard(
+         layer.weight, NamedSharding(mesh, P('model', None)))
+     layer.weight = Parameter(weight, requires_grad=False)
+
+
+ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
+     # TODO(qihqi): this does not yet handle the case of tie_word_weights=True.
+     # If that config is set, we should not create new weights but reuse the
+     # weight from VocabParallelEmbedding.
+     weight = _convert_to_torchax_and_shard(
+         layer.weight, NamedSharding(mesh, P('model', None)))
+     layer.weight = Parameter(weight, requires_grad=False)
+     if layer.bias is not None:
+         bias = _convert_to_torchax_and_shard(layer.bias,
+                                              NamedSharding(mesh, P('model')))
+         layer.bias = Parameter(bias, requires_grad=False)
+
+
+ def _shard_base_linear_lora(layer: BaseLinearLayerWithLoRA,
+                             mesh: Mesh) -> None:
+     # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
+     sharded_lora_a_tpu = torch.nn.ParameterList()
+     sharded_lora_b_tpu = torch.nn.ParameterList()
+
+     for i in range(layer.n_slices):
+         sharded_lora_a_tpu.append(
+             _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+         sharded_lora_b_tpu.append(
+             _shard_tensor_to_tpu_replicated(layer.lora_b_stacked[i], mesh))
+
+     layer.lora_a_stacked = sharded_lora_a_tpu
+     layer.lora_b_stacked = sharded_lora_b_tpu
+
+
+ # TODO: Add custom sharding logic for the following LoRA layers.
+ def _shard_column_parallel_linear_lora(
+         layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
+     _shard_base_linear_lora(layer, mesh)
+
+
+ def _shard_qkv_parallel_linear_lora(layer: MergedQKVParallelLinearWithLoRA,
+                                     mesh: Mesh) -> None:
+     _shard_base_linear_lora(layer, mesh)
+
+
+ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
+                                     mesh: Mesh) -> None:
+     _shard_base_linear_lora(layer, mesh)
+
+
+ # NOTE: Ordering is important; the first matching type for a given module wins.
+ MODULE_TYPE_TO_SHARDING_FUNC = [
+     # Shard embedding layers
+     (ParallelLMHead, _shard_lm_head),
+     (VocabParallelEmbedding, _shard_vocab_parallel_embedding),
+     # Shard LoRA layers
+     (MergedColumnParallelLinearWithLoRA, _shard_column_parallel_linear_lora),
+     (MergedQKVParallelLinearWithLoRA, _shard_qkv_parallel_linear_lora),
+     (RowParallelLinearWithLoRA, _shard_row_parallel_linear_lora),
+     (BaseLinearLayerWithLoRA, _shard_base_linear_lora),
+ ]
+
+
+ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
+     for path, module in model.named_modules():
+         for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
+             if isinstance(module, module_type):
+                 logger.debug("shard %s with %s", path, sharding_func)
+                 sharding_func(module, mesh)
+                 break
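
A hedged usage sketch of shard_model_to_tpu follows. The single "model" mesh axis matches the PartitionSpecs above, TinyModel is a placeholder module invented for illustration, and the functional_call step mirrors the docstring rather than a documented API.

import jax
import torch
from jax.sharding import Mesh

from tpu_inference.layers.vllm.sharding import shard_model_to_tpu

# A 1-D mesh over all available devices, named to match the sharding specs above.
mesh = Mesh(jax.devices(), axis_names=("model",))

class TinyModel(torch.nn.Module):  # placeholder model with CPU weights
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

model = TinyModel()
weights = shard_model_to_tpu(model, mesh)  # torchax tensors, sharded or replicated on TPU

# Per the docstring, the returned dict is intended for torch.func.functional_call:
# outputs = torch.func.functional_call(model, weights, (torch.randn(2, 16),))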
File without changes