tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py (deleted)
@@ -1,331 +0,0 @@
- from typing import Callable, Optional, Union
-
- import jax
- import jax.numpy as jnp
- import torch
- from jax.experimental.layout import Format, Layout
- from jax.sharding import Mesh, NamedSharding, PartitionSpec
- from torch.nn.parameter import Parameter
- from torchax.interop import jax_view, torch_view
- from torchax.ops.mappings import t2j
- from vllm.logger import init_logger
- from vllm.model_executor.layers.fused_moe.config import (
-     FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
- from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
-                                                          FusedMoEMethodBase)
- from vllm.model_executor.layers.linear import LinearBase
- from vllm.model_executor.layers.quantization import \
-     register_quantization_config
- from vllm.model_executor.layers.quantization.base_config import \
-     QuantizeMethodBase
- from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
-                                                            Mxfp4Config,
-                                                            Mxfp4MoEMethod)
- from vllm.model_executor.layers.quantization.utils.quant_utils import \
-     is_layer_skipped
-
- from tpu_inference import envs
- from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
- from tpu_inference.layers.common.quant_methods import (MXFP4,
-                                                         get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import fused_moe_func
- from tpu_inference.layers.vllm.linear_common import \
-     reorder_concatenated_tensor_for_sharding
- from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
- from tpu_inference.layers.vllm.quantization.unquantized import \
-     VllmUnquantizedLinearMethod
-
- MXFP4_BLOCK_SIZE = 32
-
- P = PartitionSpec
- logger = init_logger(__name__)
-
-
- # TODO(kyuyeunk): Move these functions into a common utility file.
- def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-     assert u8_packed_e2m1.dtype == jnp.uint8
-     e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-     # bitcast creates one more dimension that splits 8 bits into two e2m1.
-     # we flatten them with the last dim.
-     return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
- def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-     e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-     exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-     ones = jnp.ones_like(u8, dtype=jnp.float32)
-     return jnp.ldexp(ones, exponents)
-
-
- def dequantize_block_weight(weight: jax.Array,
-                             scale: jax.Array,
-                             block_size: int,
-                             out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-     orig_shape = weight.shape
-     weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-     weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-         scale, -1)
-     return weight_dequantized.reshape(orig_shape).astype(out_dtype)
-
-
- @register_quantization_config(get_tpu_quant_method(MXFP4))
- class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
-
-     @classmethod
-     def get_name(cls):
-         return MXFP4
-
-     def get_quant_method(self, layer: torch.nn.Module,
-                          prefix: str) -> Optional["QuantizeMethodBase"]:
-         from vllm.attention.layer import Attention  # Avoid circular import
-
-         if isinstance(layer, LinearBase):
-             linear_config = self.get_linear_config(layer)
-             if self.ignored_layers and is_layer_skipped(
-                     prefix=prefix,
-                     ignored_layers=self.ignored_layers,
-                     fused_mapping=self.packed_modules_mapping,
-             ):
-                 return VllmUnquantizedLinearMethod(linear_config)
-             logger.warning_once(
-                 "MXFP4 linear layer is not implemented - falling back to "
-                 "UnquantizedLinearMethod.")
-             return VllmUnquantizedLinearMethod(linear_config)
-         elif isinstance(layer, FusedMoE):
-             moe_config = self.get_moe_config(layer)
-             return VllmMxfp4MoEMethod(moe_config, self.mesh)
-         elif isinstance(layer, Attention):
-             logger.warning_once("MXFP4 attention layer is not implemented. "
-                                 "Skipping quantization for this layer.")
-         return None
-
-
- class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
-
-     def __init__(self,
-                  moe: FusedMoEConfig,
-                  mesh: Mesh,
-                  ep_axis_name: str = 'model'):
-         FusedMoEMethodBase.__init__(self, moe)
-
-         # We piggyback on triton implementation as it applies minimal hardware
-         # specific post processing to the weights.
-         self.mxfp4_backend = Mxfp4Backend.TRITON
-
-         self.mesh = mesh
-         self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-         self.ep_axis_name = ep_axis_name
-         # TODO: Use autotune table once we have it.
-         self.block_size = {
-             "bt": 64,
-             "bf": 1024,
-             "bd1": 1536,
-             "bd2": 1536,
-             "btc": 64,
-             "bfc": 1024,
-             "bd1c": 1536,
-             "bd2c": 1536,
-         }
-
-     def get_fused_moe_quant_config(
-             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-         # Because we have dequantized weights, we only need biased moe config.
-         # TODO(kyuyeunk): Add native support for MXFP4.
-         return biased_moe_quant_config(
-             layer.w13_bias,
-             layer.w2_bias,
-         )
-
-     def process_weights_after_loading(self, layer: torch.nn.Module):
-         assert isinstance(layer, FusedMoE)
-         assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
-
-         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
-         w13_weight_scale = e8m0_to_fp32(
-             t2j(layer.w13_weight_scale, use_dlpack=False))
-         w13_bias = t2j(layer.w13_bias, use_dlpack=False)
-
-         w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
-         w2_weight_scale = e8m0_to_fp32(
-             t2j(layer.w2_weight_scale, use_dlpack=False))
-         w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-         # We dequantize fp4 weights into bf16.
-         # TODO(kyuyeunk): Add native support for MXFP4.
-         w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
-                                              MXFP4_BLOCK_SIZE, jnp.bfloat16)
-         w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
-                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
-
-         num_experts, hidden_size, intermediate_size = w2_weight.shape
-
-         # Because we have dequantized weights, scales are not used anymore.
-         delattr(layer, "w13_weight_scale")
-         delattr(layer, "w2_weight_scale")
-
-         if layer.activation == "swigluoai":
-             # When using swigluoai, vLLM splits gmm output in a interleaved way.
-             # However, interleaved split is not performant on TPU. Therefore,
-             # we preprocess the weight so that splitting gmm output by middle
-             # can still get the same result.
-             w1_weight = w13_weight[:, ::2, :]
-             w3_weight = w13_weight[:, 1::2, :]
-             w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-             w1_bias = w13_bias[:, ::2]
-             w3_bias = w13_bias[:, 1::2]
-             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-         if self.use_kernel:
-             # Kernel expects:
-             # w13: (num_experts, 2, hidden_size, intermediate_size)
-             # w2: (num_experts, intermediate_size, hidden_size)
-             # Current format:
-             # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-             # w2_weight: (num_experts, hidden_size, intermediate_size)
-
-             w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                               intermediate_size, hidden_size)
-
-             # Transpose non-constracting dim to right most dim
-             w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-             w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
-
-             # Apply EP sharding
-             ep_sharding = NamedSharding(self.mesh, P("model"))
-
-             w13_weight = jax.device_put(
-                 w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                               ep_sharding))
-             w2_weight = jax.device_put(w2_weight_transposed,
-                                        Format(Layout((0, 1, 2)), ep_sharding))
-
-             w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-             w13_bias = jax.device_put(w13_bias,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
-             w2_bias = jax.device_put(w2_bias,
-                                      Format(Layout((0, 1)), ep_sharding))
-
-         else:
-             if layer.use_ep:
-                 ep_sharding = NamedSharding(self.mesh, P("model"))
-                 w13_weight = jax.device_put(
-                     w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
-                 w2_weight = jax.device_put(
-                     w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
-
-                 w13_bias = jax.device_put(w13_bias,
-                                           Format(Layout((0, 1)), ep_sharding))
-                 w2_bias = jax.device_put(w2_bias,
-                                          Format(Layout((0, 1)), ep_sharding))
-
-             else:
-                 output_sizes = [intermediate_size, intermediate_size]
-                 n_shards = self.mesh.shape["model"]
-                 assert intermediate_size % n_shards == 0
-
-                 w13_weight = reorder_concatenated_tensor_for_sharding(
-                     w13_weight,
-                     output_sizes,
-                     n_shards,
-                     dim=1,
-                 )
-                 w13_weight = jax.device_put(
-                     w13_weight,
-                     Format(Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, "model", None))))
-                 w2_weight = jax.device_put(
-                     w2_weight,
-                     Format(Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, None, "model"))))
-
-                 w13_bias = reorder_concatenated_tensor_for_sharding(
-                     w13_bias,
-                     output_sizes,
-                     n_shards,
-                     dim=1,
-                 )
-                 w13_bias = jax.device_put(
-                     w13_bias,
-                     Format(Layout((0, 1)),
-                            NamedSharding(self.mesh, P(None, "model"))))
-                 w2_bias = jax.device_put(
-                     w2_bias,
-                     Format(Layout((0, 1)),
-                            NamedSharding(self.mesh, P(None, None))))
-
-         layer.w13_weight = Parameter(torch_view(w13_weight),
-                                      requires_grad=False)
-         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-
-         layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-         layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
-
-         pass
-
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool = False,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         global_num_experts: int = -1,
-         expert_map: Optional[torch.Tensor] = None,
-         custom_routing_function: Optional[Callable] = None,
-         scoring_func: str = "softmax",
-         routed_scaling_factor: float = 1.0,
-         e_score_correction_bias: Optional[torch.Tensor] = None,
-         apply_router_weight_on_input: bool = False,
-         activation: str = "silu",
-         enable_eplb: bool = False,
-         expert_load_view: Optional[torch.Tensor] = None,
-         logical_to_physical_map: Optional[torch.Tensor] = None,
-         logical_replica_count: Optional[torch.Tensor] = None,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-         assert isinstance(layer, FusedMoE)
-         if scoring_func != "softmax":
-             raise NotImplementedError(
-                 "Only softmax is supported for scoring_func")
-
-         x = jax_view(x)
-         w13_weight = jax_view(layer.w13_weight)
-         w2_weight = jax_view(layer.w2_weight)
-         w13_bias = jax_view(layer.w13_bias)
-         w2_bias = jax_view(layer.w2_bias)
-         gating_output = jax_view(router_logits)
-
-         if self.use_kernel:
-             output = fused_ep_moe(
-                 mesh=self.mesh,
-                 tokens=x,
-                 w1=w13_weight,
-                 w2=w2_weight,
-                 b1=w13_bias,
-                 b2=w2_bias,
-                 gating_output=gating_output,
-                 top_k=top_k,
-                 ep_axis_name=self.ep_axis_name,
-                 renormalize_topk_logits=renormalize,
-                 act_fn=activation,
-                 **self.block_size,
-             )
-         else:
-             output = fused_moe_func(
-                 hidden_states=x,
-                 w1=w13_weight,
-                 w2=w2_weight,
-                 w1_bias=w13_bias,
-                 w2_bias=w2_bias,
-                 gating_output=gating_output,
-                 topk=top_k,
-                 renormalize=renormalize,
-                 mesh=self.mesh,
-                 use_ep=layer.use_ep,
-                 activation=activation,
-             )
-
-         return torch_view(output)
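
For context on the hunk above: the removed VllmMxfp4MoEMethod does not run MXFP4 natively. At load time it unpacks the two-per-byte FP4 weights, expands the E8M0 scales, and dequantizes block-wise to bfloat16 via the three module-level helpers; the later kernel/non-kernel branches only reshape and shard the already-dequantized tensors. A minimal, self-contained sketch of that dequantization flow (zero-filled inputs and illustrative shapes, not the package's real tensors):

import jax
import jax.numpy as jnp

MXFP4_BLOCK_SIZE = 32  # one E8M0 scale per 32 FP4 values, as in the removed file


def u8_unpack_e2m1(packed: jax.Array) -> jax.Array:
    # Each uint8 holds two float4_e2m1fn values; the bitcast adds a trailing
    # axis of size 2, which is folded back into the last dimension.
    e2m1 = jax.lax.bitcast_convert_type(packed, jnp.float4_e2m1fn)
    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1,))


def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
    # E8M0 stores only a biased exponent, so ldexp(1.0, exponent) recovers the scale.
    exponents = u8.astype(jnp.int32) + jnp.finfo(jnp.float8_e8m0fnu).minexp
    return jnp.ldexp(jnp.ones_like(u8, dtype=jnp.float32), exponents)


def dequantize_block_weight(weight, scale, block_size, out_dtype=jnp.bfloat16):
    # One scale per `block_size` contiguous values along the last axis.
    orig_shape = weight.shape
    blocks = weight.reshape(orig_shape[:-1] + (-1, block_size))
    dequant = blocks.astype(jnp.float32) * jnp.expand_dims(scale, -1)
    return dequant.reshape(orig_shape).astype(out_dtype)


# Illustrative shapes only: 8 experts, 64 output rows, hidden size 256
# packed two-per-byte into 128 uint8 values, with 256 / 32 = 8 scales per row.
packed = jnp.zeros((8, 64, 128), dtype=jnp.uint8)
scales = jnp.zeros((8, 64, 8), dtype=jnp.uint8)
w_bf16 = dequantize_block_weight(u8_unpack_e2m1(packed), e8m0_to_fp32(scales),
                                 MXFP4_BLOCK_SIZE)
print(w_bf16.shape, w_bf16.dtype)  # (8, 64, 256) bfloat16
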
tpu_inference/models/jax/llama_guard_4.py (deleted)
@@ -1,361 +0,0 @@
- import re
- from typing import Any, List, Optional, Tuple
-
- import jax
- import jax.numpy as jnp
- import torch
- from flax import nnx
- from flax.typing import PRNGKey
- from jax.sharding import Mesh
- from jax.sharding import PartitionSpec as P
- from vllm.config import VllmConfig
-
- from tpu_inference.layers.jax.attention.attention import AttentionMetadata
- from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
- from tpu_inference.layers.jax.constants import KVCacheType
- from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
- from tpu_inference.layers.jax.misc import shard_put
- from tpu_inference.layers.jax.transformer_block import TransformerBlock
- from tpu_inference.logger import init_logger
- from tpu_inference.models.jax.utils.weight_utils import (
-     get_param, model_weights_generator, print_param_info, reshape_params,
-     transpose_params)
-
- logger = init_logger(__name__)
-
-
- class LlamaGuard4ForCausalLM(nnx.Module):
-
-     def __init__(self,
-                  vllm_config: VllmConfig,
-                  rng: PRNGKey,
-                  mesh: Mesh,
-                  force_random_weights: bool = False):
-         logger.warning(
-             "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
-             "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
-             "Multimodal inputs will fail.\n"
-             "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
-         assert mesh is not None
-
-         self.vllm_config = vllm_config
-         self.vllm_config.model_config.dtype = torch.bfloat16
-         model_config = vllm_config.model_config
-         text_config = model_config.hf_config.text_config
-
-         self.mesh = mesh
-         self.is_verbose = getattr(self.vllm_config.additional_config,
-                                   "is_verbose", False)
-
-         self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
-
-         vocab_size = model_config.get_vocab_size()
-         self.hidden_size = model_config.get_hidden_size()
-
-         self.dtype: jnp.dtype = jnp.bfloat16
-
-         self.num_layers: int = getattr(text_config, "num_layers", 48)
-         hidden_act: str = getattr(text_config, "hidden_act", "silu")
-
-         rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
-         self.num_attention_heads = getattr(text_config, "num_attention_heads",
-                                            40)
-         self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
-                                            8)
-         self.head_dim = getattr(text_config, "head_dim", 128)
-
-         intermediate_size = getattr(text_config, "intermediate_size", 8192)
-
-         self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
-         self.rope_scaling = getattr(text_config, "rope_scaling")
-
-         self.rng = nnx.Rngs(rng)
-
-         self.embedder = Embedder(
-             vocab_size=vocab_size,
-             hidden_size=self.hidden_size,
-             dtype=self.dtype,
-             vd_sharding=(('data', 'model'), None),
-             rngs=self.rng,
-             random_init=force_random_weights,
-         )
-
-         self.layers = []
-
-         for i in range(self.num_layers):
-             use_attention_rope = True
-
-             custom_module = DenseFFW(dtype=self.dtype,
-                                      hidden_act=hidden_act,
-                                      hidden_size=self.hidden_size,
-                                      intermediate_size=intermediate_size,
-                                      random_init=force_random_weights,
-                                      rngs=self.rng,
-                                      df_sharding=P(None, 'model'),
-                                      fd_sharding=P('model', None),
-                                      activation_ffw_td=P('data', None))
-
-             attn = Llama4Attention(
-                 hidden_size=self.hidden_size,
-                 dtype=self.dtype,
-                 num_attention_heads=self.num_attention_heads,
-                 num_key_value_heads=self.num_key_value_heads,
-                 head_dim=self.head_dim,
-                 rope_theta=self.rope_theta_text,
-                 rope_scaling={
-                     "scale_factor":
-                     self.rope_scaling["factor"],
-                     "low_freq_factor":
-                     self.rope_scaling["low_freq_factor"],
-                     "high_freq_factor":
-                     self.rope_scaling["high_freq_factor"],
-                     "original_max_position_embeddings":
-                     self.rope_scaling["original_max_position_embeddings"]
-                 },
-                 rngs=self.rng,
-                 rope_input_ordering="interleaved",
-                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                 kv_cache_dtype=vllm_config.cache_config.cache_dtype,
-                 temperature_tuning=True,
-                 temperature_tuning_scale=0.1,
-                 temperature_tuning_floor_scale=8192,
-                 use_qk_norm=self.use_qk_norm,
-                 attention_chunk_size=None if use_attention_rope else 8192,
-                 mesh=self.mesh,
-                 random_init=force_random_weights,
-                 activation_attention_td=('data', 'model'),
-                 activation_q_td=('data', 'model'),
-                 query_tnh=P('data', 'model', None),
-                 keyvalue_skh=P('data', 'model', None),
-                 activation_attention_out_td=('data', 'model'),
-                 attn_o_tnh=P('data', 'model', None),
-                 dnh_sharding=(None, 'model', None),
-                 dkh_sharding=(None, 'model', None),
-                 nhd_sharding=('model', None, None),
-             )
-
-             pre_attention_norm = RMSNorm(
-                 dims=self.hidden_size,
-                 random_init=force_random_weights,
-                 epsilon=rms_norm_eps,
-                 rngs=self.rng,
-                 activation_ffw_td=('data', None),
-                 with_scale=True,
-                 dtype=self.dtype,
-             )
-
-             pre_mlp_norm = RMSNorm(
-                 dims=self.hidden_size,
-                 activation_ffw_td=('data', None),
-                 epsilon=rms_norm_eps,
-                 rngs=self.rng,
-                 with_scale=True,
-                 dtype=self.dtype,
-                 random_init=force_random_weights,
-             )
-
-             block = TransformerBlock(custom_module=custom_module,
-                                      attn=attn,
-                                      pre_attention_norm=pre_attention_norm,
-                                      pre_mlp_norm=pre_mlp_norm,
-                                      use_attention_rope=use_attention_rope)
-             self.layers.append(block)
-
-         self.final_norm = RMSNorm(
-             dims=self.hidden_size,
-             activation_ffw_td=P(),
-             epsilon=rms_norm_eps,
-             rngs=self.rng,
-             with_scale=True,
-             dtype=self.dtype,
-             random_init=force_random_weights,
-         )
-
-         self.lm_head = LMhead(vocab_size=vocab_size,
-                               hidden_size=self.hidden_size,
-                               dtype=self.dtype,
-                               rngs=self.rng,
-                               vd_sharding=(('data', 'model'), None),
-                               dv_sharding=(None, ('data', 'model')),
-                               random_init=force_random_weights)
-         if self.is_verbose:
-             self._print_model_architecture()
-
-     def _print_model_architecture(self):
-
-         logger.info("### Embedding ###")
-         nnx.display(self.embedder)
-
-         logger.info("\n### Layers ###")
-         for i, layer in enumerate(self.layers):
-             logger.info(f"\n--- Layer {i} ---")
-             nnx.display(layer)
-
-         logger.info("\n### LM Head ###")
-         nnx.display(self.lm_head)
-
-     def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
-         self.rng = nnx.Rngs(rng)
-
-         weight_loader = LlamaGuard4WeightLoader(
-             vllm_config=self.vllm_config,
-             hidden_size=self.hidden_size,
-             attn_heads=self.num_attention_heads,
-             num_key_value_heads=self.num_key_value_heads,
-             attn_head_dim=self.head_dim)
-         weight_loader.load_weights(self)
-
-     def __call__(
-         self,
-         kv_caches: List[jax.Array],
-         input_ids: jax.Array,
-         attention_metadata: AttentionMetadata,
-         inputs_embeds: Optional[jax.Array] = None,
-         layer_metadata_tuple: Optional[Tuple] = None,
-         lora_metadata: Optional[Any] = None,
-         *args,
-     ) -> Tuple[List[KVCacheType], jax.Array]:
-         is_prefill = False
-
-         if inputs_embeds is not None:
-             x_TD = inputs_embeds
-         elif input_ids is not None:
-             x_TD = self.embedder.encode(input_ids)
-         else:
-             raise ValueError(
-                 "Cannot run forward pass: Both input_ids and inputs_embeds are None."
-             )
-
-         for (i, block) in enumerate(self.layers):
-             kv_cache = kv_caches[i]
-             new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
-                                        attention_metadata)
-             jax.block_until_ready(x_TD)
-             kv_caches[i] = new_kv_cache
-
-         final_activation_TD = self.final_norm(x_TD)
-
-         return kv_caches, final_activation_TD, []
-
-     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
-         logits_TV = jnp.dot(hidden_states,
-                             self.lm_head.input_embedding_table_DV.value)
-         return logits_TV
-
-     def get_input_embeddings(
-         self,
-         input_ids: jax.Array,
-         multimodal_embeddings: Optional[List[jax.Array]] = None
-     ) -> jax.Array:
-         """
-         Computes the embeddings for text input (used for input to fusion).
-         """
-         return self.embedder.encode(input_ids)
-
-
- class LlamaGuard4WeightLoader:
-
-     def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
-                  num_key_value_heads, attn_head_dim):
-         self.names_and_weights_generator = model_weights_generator(
-             model_name_or_path=vllm_config.model_config.model,
-             framework="flax",
-             filter_regex="language_model",
-             download_dir=vllm_config.load_config.download_dir)
-         self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
-                                   False)
-         self._transpose_map = {
-             "q_proj": (2, 0, 1),
-             "k_proj": (2, 0, 1),
-             "v_proj": (2, 0, 1),
-             "o_proj": (1, 2, 0),
-             "lm_head": (1, 0),
-             "feed_forward.down_proj": (1, 0),
-             "feed_forward.gate_proj": (1, 0),
-             "feed_forward.up_proj": (1, 0),
-             "mlp.down_proj": (1, 0),
-             "mlp.gate_proj": (1, 0),
-             "mlp.up_proj": (1, 0),
-         }
-         self._weight_shape_map = {
-             "q_proj": (attn_heads, attn_head_dim, hidden_size),
-             "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
-             "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
-             "o_proj": (hidden_size, attn_heads, attn_head_dim),
-         }
-
-         self._loaded_to_standardized_keys = {
-             "language_model.model.embed_tokens.weight":
-             "embedder.input_embedding_table_VD",
-             "language_model.lm_head.weight":
-             "lm_head.input_embedding_table_DV",
-             "language_model.model.norm.weight":
-             "final_norm.scale",
-             "language_model.model.layers.*.input_layernorm.weight":
-             "layers.*.pre_attention_norm.scale",
-             "language_model.model.layers.*.post_attention_layernorm.weight":
-             "layers.*.pre_mlp_norm.scale",
-             "language_model.model.layers.*.self_attn.q_proj.weight":
-             "layers.*.attn.kernel_q_proj_DNH",
-             "language_model.model.layers.*.self_attn.k_proj.weight":
-             "layers.*.attn.kernel_k_proj_DKH",
-             "language_model.model.layers.*.self_attn.v_proj.weight":
-             "layers.*.attn.kernel_v_proj_DKH",
-             "language_model.model.layers.*.self_attn.o_proj.weight":
-             "layers.*.attn.kernel_o_proj_NHD",
-             "language_model.model.layers.*.feed_forward.gate_proj.weight":
-             "layers.*.custom_module.kernel_gating_DF",
-             "language_model.model.layers.*.feed_forward.up_proj.weight":
-             "layers.*.custom_module.kernel_up_proj_DF",
-             "language_model.model.layers.*.feed_forward.down_proj.weight":
-             "layers.*.custom_module.kernel_down_proj_FD",
-         }
-
-     def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
-         if "layer" in loaded_key:
-             layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
-             layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
-             mapped_key = self._loaded_to_standardized_keys.get(
-                 layer_key, loaded_key)
-             mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
-                                 mapped_key)
-         else:
-             mapped_key = self._loaded_to_standardized_keys.get(
-                 loaded_key, loaded_key)
-         return mapped_key
-
-     def load_weights(self, model_for_loading: nnx.Module):
-         model_params = nnx.state(model_for_loading)
-         with jax.default_device(jax.devices("cpu")[0]):
-             for loaded_name, loaded_weight in self.names_and_weights_generator:
-                 if loaded_name.endswith(".bias"):
-                     continue
-                 if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
-                     continue
-
-                 mapped_name = self.map_loaded_to_standardized_name(loaded_name)
-                 model_weight = get_param(model_params, mapped_name)
-
-                 if not loaded_name.endswith(".bias"):
-                     # For other layers, continue to use the transpose_params helper.
-                     loaded_weight = reshape_params(loaded_name, loaded_weight,
-                                                    self._weight_shape_map)
-                     loaded_weight = transpose_params(loaded_name,
-                                                      loaded_weight,
-                                                      self._transpose_map)
-                 if model_weight.value.shape != loaded_weight.shape:
-                     raise ValueError(
-                         f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
-                         f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
-                     )
-                 logger.debug(
-                     f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
-                 )
-
-                 model_weight.value = shard_put(loaded_weight,
-                                                model_weight.sharding,
-                                                mesh=model_for_loading.mesh)
-                 if self.is_verbose:
-                     print_param_info(model_weight, loaded_name)
-
-         nnx.update(model_for_loading, model_params)
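
For context on the hunk above: the removed LlamaGuard4WeightLoader resolves HuggingFace checkpoint names to nnx parameter paths by swapping the concrete layer index for a wildcard, looking the wildcard template up in its key table, then substituting the index back in. A standalone sketch of that mapping step (the table is truncated to three of the removed entries for brevity):

import re

# Truncated copy of the removed _loaded_to_standardized_keys table.
_LOADED_TO_STANDARDIZED = {
    "language_model.model.embed_tokens.weight":
    "embedder.input_embedding_table_VD",
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
    "language_model.model.layers.*.feed_forward.down_proj.weight":
    "layers.*.custom_module.kernel_down_proj_FD",
}


def map_loaded_to_standardized_name(loaded_key: str) -> str:
    if "layers." in loaded_key:
        # Replace the concrete layer index with the wildcard used by the table,
        # look the template up, then put the index back into the result.
        layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
        template = re.sub(r"layers\.\d+", "layers.*", loaded_key)
        mapped = _LOADED_TO_STANDARDIZED.get(template, loaded_key)
        return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped)
    return _LOADED_TO_STANDARDIZED.get(loaded_key, loaded_key)


print(map_loaded_to_standardized_name(
    "language_model.model.layers.12.self_attn.q_proj.weight"))
# -> layers.12.attn.kernel_q_proj_DNH
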
File without changes