tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama3.py
@@ -0,0 +1,366 @@
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import Mesh
+from transformers import LlamaConfig, modeling_flax_utils
+from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention_interface import attention
+from tpu_inference.layers.jax.rope_interface import apply_rope
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
+                                                         load_hf_weights)
+
+logger = init_logger(__name__)
+
+init_fn = nnx.initializers.uniform()
+
+
+class LlamaMLP(nnx.Module):
+
+    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs):
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        act = config.hidden_act
+
+        self.gate_proj = nnx.Linear(
+            hidden_size,
+            intermediate_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+            rngs=rng,
+        )
+        self.up_proj = nnx.Linear(
+            hidden_size,
+            intermediate_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+            rngs=rng,
+        )
+        self.down_proj = nnx.Linear(
+            intermediate_size,
+            hidden_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+            rngs=rng,
+        )
+        self.act_fn = modeling_flax_utils.ACT2FN[act]
+
+    def __call__(self, x: jax.Array) -> jax.Array:
+        gate = self.act_fn(self.gate_proj(x))
+        up = self.up_proj(x)
+        fuse = gate * up
+        result = self.down_proj(fuse)
+        return result
+
+
+class LlamaAttention(nnx.Module):
+
+    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.rope_theta = config.rope_theta
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+
+        self.head_dim_original = getattr(config, "head_dim",
+                                         self.hidden_size // self.num_heads)
+        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+        sharding_size = mesh.shape["model"]
+        self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                    sharding_size)
+        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                       sharding_size)
+
+        self.mesh = mesh
+
+        self.q_proj = nnx.Einsum(
+            "TD,DNH->TNH",
+            (self.hidden_size, self.num_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.k_proj = nnx.Einsum(
+            "TD,DKH->TKH",
+            (self.hidden_size, self.num_kv_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.v_proj = nnx.Einsum(
+            "TD,DKH->TKH",
+            (self.hidden_size, self.num_kv_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.o_proj = nnx.Einsum(
+            "TNH,NHD->TD",
+            (self.num_heads, self.head_dim, self.hidden_size),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+            rngs=rng,
+        )
+
+        self._q_scale = 1.0
+        self._k_scale = 1.0
+        self._v_scale = 1.0
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                kv_cache_dtype)
+
+    def __call__(
+        self,
+        kv_cache: Optional[jax.Array],
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        md = attention_metadata
+        # q: (T, N, H)
+        q = self.q_proj(x)
+        q = apply_rope(q, md.input_positions, self.head_dim_original,
+                       self.rope_theta, self.rope_scaling)
+        # k: (T, K, H)
+        k = self.k_proj(x)
+        k = apply_rope(k, md.input_positions, self.head_dim_original,
+                       self.rope_theta, self.rope_scaling)
+        # v: (T, K, H)
+        v = self.v_proj(x)
+        # o: (T, N, H)
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = self._q_scale
+            k_scale = self._k_scale
+            v_scale = self._v_scale
+            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                     k_scale, v_scale)
+        new_kv_cache, outputs = attention(
+            kv_cache,
+            q,
+            k,
+            v,
+            attention_metadata,
+            self.mesh,
+            self.head_dim_original,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+        # (T, D)
+        o = self.o_proj(outputs)
+        return new_kv_cache, o
+
+
+class LlamaDecoderLayer(nnx.Module):
+
+    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        rms_norm_eps = config.rms_norm_eps
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.self_attn = LlamaAttention(config=config,
+                                        dtype=dtype,
+                                        rng=rng,
+                                        mesh=mesh,
+                                        kv_cache_dtype=kv_cache_dtype)
+        self.post_attention_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.mlp = LlamaMLP(
+            config=config,
+            dtype=dtype,
+            rng=rng,
+        )
+
+    def __call__(
+        self,
+        kv_cache: jax.Array,
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        hidden_states = self.input_layernorm(x)
+        kv_cache, attn_output = self.self_attn(
+            kv_cache,
+            hidden_states,
+            attention_metadata,
+        )
+        attn_output += x
+
+        residual = attn_output
+        attn_output = self.post_attention_layernorm(attn_output)
+        outputs = self.mlp(attn_output)
+        outputs = residual + outputs
+        return kv_cache, outputs
+
+
+class LlamaModel(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                 mesh: Mesh) -> None:
+        model_config = vllm_config.model_config
+        hf_config = model_config.hf_config
+        vocab_size = model_config.get_vocab_size()
+        dtype = model_config.dtype
+        rms_norm_eps = hf_config.rms_norm_eps
+        hidden_size = hf_config.hidden_size
+
+        self.embed = nnx.Embed(
+            num_embeddings=vocab_size,
+            features=hidden_size,
+            param_dtype=dtype,
+            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+            rngs=rng,
+        )
+        self.layers = [
+            LlamaDecoderLayer(
+                config=hf_config,
+                dtype=dtype,
+                rng=rng,
+                mesh=mesh,
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+            for _ in range(hf_config.num_hidden_layers)
+        ]
+        self.norm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        if model_config.hf_config.tie_word_embeddings:
+            self.lm_head = self.embed.embedding
+        else:
+            self.lm_head = nnx.Param(
+                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                sharding=(None, "model"),
+            )
+
+        self.aux_hidden_state_layers = []
+        if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
+            self.aux_hidden_state_layers = self.get_eagle3_aux_hidden_state_layers()
+
+    def get_eagle3_aux_hidden_state_layers(self):
+        num_layers = len(self.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+        x = self.embed(input_ids)
+        aux_hidden_states = []
+        for i, layer in enumerate(self.layers):
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x)
+            kv_cache = kv_caches[i]
+            kv_cache, x = layer(
+                kv_cache,
+                x,
+                attention_metadata,
+            )
+            kv_caches[i] = kv_cache
+        x = self.norm(x)
+        return kv_caches, x, aux_hidden_states
+
+
+class LlamaForCausalLM(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                 mesh: Mesh) -> None:
+        self.vllm_config = vllm_config
+        self.rng = nnx.Rngs(rng_key)
+        self.mesh = mesh
+
+        self.model = LlamaModel(
+            vllm_config=vllm_config,
+            rng=self.rng,
+            mesh=mesh,
+        )
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+        kv_caches, x, aux_hidden_states = self.model(
+            kv_caches,
+            input_ids,
+            attention_metadata,
+        )
+        return kv_caches, x, aux_hidden_states
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+        else:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value)
+        return logits
+
+    def load_weights(self, rng_key: jax.Array):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng_key)
+
+        # Key: path to a HF layer weight
+        # Value: path to a nnx layer weight
+        mappings = {
+            "model.embed_tokens": "model.embed.embedding",
+            "model.layers.*.input_layernorm":
+            "model.layers.*.input_layernorm.scale",
+            "model.layers.*.mlp.down_proj":
+            "model.layers.*.mlp.down_proj.kernel",
+            "model.layers.*.mlp.gate_proj":
+            "model.layers.*.mlp.gate_proj.kernel",
+            "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
+            "model.layers.*.post_attention_layernorm":
+            "model.layers.*.post_attention_layernorm.scale",
+            "model.layers.*.self_attn.k_proj":
+            "model.layers.*.self_attn.k_proj.kernel",
+            "model.layers.*.self_attn.o_proj":
+            "model.layers.*.self_attn.o_proj.kernel",
+            "model.layers.*.self_attn.q_proj":
+            "model.layers.*.self_attn.q_proj.kernel",
+            "model.layers.*.self_attn.v_proj":
+            "model.layers.*.self_attn.v_proj.kernel",
+            "model.norm": "model.norm.scale",
+        }
+        # Add lm_head mapping only if it's not tied to embeddings
+        if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            mappings.update({
+                "lm_head": "model.lm_head",
+            })
+
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        load_hf_weights(vllm_config=self.vllm_config,
+                        model=self,
+                        metadata_map=metadata_map,
+                        mesh=self.mesh)
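As a shape sanity-check, here is a minimal sketch (not part of the package) that instantiates the same q/o einsum patterns used by LlamaAttention above with toy sizes; the dimensions, the float32 dtype, and the unsharded single-device setup are illustrative assumptions, not the model's actual configuration.

# Minimal sketch, assuming flax.nnx is available; sizes are illustrative only.
import jax.numpy as jnp
from flax import nnx

T, D, N, H = 16, 32, 4, 8  # tokens, hidden size, attention heads, head dim
rngs = nnx.Rngs(0)

# Same einsum patterns as q_proj / o_proj above, without the "model"-axis partitioning.
q_proj = nnx.Einsum("TD,DNH->TNH", (D, N, H), param_dtype=jnp.float32, rngs=rngs)
o_proj = nnx.Einsum("TNH,NHD->TD", (N, H, D), param_dtype=jnp.float32, rngs=rngs)

x = jnp.ones((T, D))
q = q_proj(x)            # (T, N, H): one query vector per token and head
o = o_proj(q)            # (T, D): heads folded back into the hidden dimension
print(q.shape, o.shape)  # (16, 4, 8) (16, 32)

In the real layer these kernels are annotated with nnx.with_partitioning so the head dimension is sharded across the mesh's "model" axis; the sketch omits that to stay runnable on a single device.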