tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama_eagle3.py
@@ -0,0 +1,333 @@
+ from typing import List, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import LlamaConfig
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.llama3 import LlamaDecoderLayer
+ from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                          get_default_maps,
+                                                          load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+
+ class Eagle3LlamaDecoderLayer(LlamaDecoderLayer):
+
+     def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         super().__init__(config,
+                          dtype=dtype,
+                          rng=rng,
+                          mesh=mesh,
+                          kv_cache_dtype=kv_cache_dtype)
+         self.config = config
+         # Override qkv: the layer consumes the concatenation of token
+         # embeddings and target hidden states, so its input is 2x wide.
+         hidden_size = 2 * self.self_attn.hidden_size
+         self.self_attn.q_proj = nnx.Einsum(
+             "TD,DNH->TNH",
+             (hidden_size, self.self_attn.num_heads, self.self_attn.head_dim),
+             param_dtype=dtype,
+             dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.self_attn.k_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (hidden_size, self.self_attn.num_kv_heads,
+              self.self_attn.head_dim),
+             param_dtype=dtype,
+             dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.self_attn.v_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (hidden_size, self.self_attn.num_kv_heads,
+              self.self_attn.head_dim),
+             param_dtype=dtype,
+             dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         # Override input layernorm and specify dtype to avoid unexpected upcasting.
+         self.input_layernorm = nnx.RMSNorm(
+             config.hidden_size,
+             epsilon=config.rms_norm_eps,
+             param_dtype=dtype,
+             dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.hidden_norm = nnx.RMSNorm(
+             config.hidden_size,
+             epsilon=config.rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+
+     def _norm_before_residual(
+             self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
+         hidden_states = self.hidden_norm(hidden_states)
+         residual = hidden_states
+         return hidden_states, residual
+
+     def _norm_after_residual(
+             self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
+         residual = hidden_states
+         hidden_states = self.hidden_norm(hidden_states)
+         return hidden_states, residual
+
+     def __call__(
+         self,
+         kv_cache: jax.Array,
+         embeds: jax.Array,
+         hidden_states: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+         embeds = self.input_layernorm(embeds)
+         if getattr(self.config, "norm_before_residual", False):
+             hidden_states, residual = self._norm_before_residual(
+                 hidden_states=hidden_states)
+         else:
+             hidden_states, residual = self._norm_after_residual(
+                 hidden_states=hidden_states)
+         hidden_states = jnp.concatenate([embeds, hidden_states], axis=-1)
+
+         kv_cache, attn_output = self.self_attn(
+             kv_cache,
+             hidden_states,
+             attention_metadata,
+         )
+
+         # TODO(ranlihao): Check if this residual connection is correct.
+         hidden_states = attn_output + residual
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         mlp_output = self.mlp(hidden_states)
+
+         return kv_cache, mlp_output, residual
+
+
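
The doubled hidden_size above exists because an EAGLE-3 layer attends over the concatenation of the token embeddings and the target model's hidden states (see the layer's __call__). A minimal shape sketch of that concatenation, with made-up sizes:

    import jax.numpy as jnp

    T, D = 4, 8                        # tokens, target hidden size (made up)
    embeds = jnp.ones((T, D))          # normalized token embeddings
    hidden_states = jnp.ones((T, D))   # normalized target-model hidden states
    x = jnp.concatenate([embeds, hidden_states], axis=-1)
    print(x.shape)                     # (4, 16): width 2 * D, matching the
                                       # rebuilt q/k/v projections
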
+ class Eagle3LlamaModel(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs, mesh: Mesh):
+         super().__init__()
+         hf_config = vllm_config.speculative_config.draft_model_config.hf_config
+         dtype: jnp.dtype = jnp.bfloat16
+
+         self.embed_tokens = nnx.Embed(
+             num_embeddings=hf_config.vocab_size,
+             features=hf_config.hidden_size,
+             param_dtype=dtype,
+             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+
+         self.layers = [
+             Eagle3LlamaDecoderLayer(
+                 config=hf_config,
+                 dtype=dtype,
+                 rng=rng,
+                 mesh=mesh,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+         ]
+
+         if hasattr(hf_config, "target_hidden_size"):
+             input_size = hf_config.target_hidden_size * 3
+         else:
+             input_size = hf_config.hidden_size * 3
+
+         self.fc = nnx.Linear(
+             in_features=input_size,
+             out_features=hf_config.hidden_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=rng,
+         )
+
+         self.norm = nnx.RMSNorm(
+             hf_config.hidden_size,
+             epsilon=hf_config.rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         hidden_states: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         embeds = self.embed_tokens(input_ids)
+         assert hidden_states.shape[-1] == embeds.shape[-1]
+
+         assert len(self.layers) == 1
+         # The first N - 1 KV caches are for the target model and the last one
+         # is for the draft model, where N is the number of layers in the
+         # target model. The draft model has only 1 layer.
+         kv_caches[-1], hidden_states, residual = self.layers[0](
+             kv_caches[-1],
+             embeds,
+             hidden_states,
+             attention_metadata,
+         )
+
+         # TODO(ranlihao): Check if this residual connection is correct.
+         hidden_states = hidden_states + residual
+         residual = hidden_states
+         hidden_states = self.norm(hidden_states)
+         return kv_caches, hidden_states, [residual]
+
+
+ def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
+                                   metadata_map: MetadataMap):
+     model_config = vllm_config.model_config
+     hf_config = model_config.hf_config
+
+     num_heads = hf_config.num_attention_heads
+     num_kv_heads = hf_config.num_key_value_heads
+     hidden_size = model_config.get_hidden_size()
+
+     head_dim_original = model_config.get_head_size()
+
+     metadata_map.reshape_map.update({
+         "q_proj": (num_heads, head_dim_original, 2 * hidden_size),
+         "k_proj": (num_kv_heads, head_dim_original, 2 * hidden_size),
+         "v_proj": (num_kv_heads, head_dim_original, 2 * hidden_size),
+     })
+
+
+ class EagleLlama3ForCausalLM(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh):
+         nnx.Module.__init__(self)
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+         dtype: jnp.dtype = jnp.bfloat16
+
+         spec_config = vllm_config.speculative_config
+         assert spec_config is not None
+         model_config = spec_config.draft_model_config
+         assert model_config is not None
+         hf_config = model_config.hf_config
+
+         self.model = Eagle3LlamaModel(
+             vllm_config=vllm_config,
+             rng=self.rng,
+             mesh=mesh,
+         )
+
+         self.lm_head = nnx.Linear(
+             hf_config.hidden_size,
+             hf_config.draft_vocab_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=self.rng,
+         )
+
+         self.draft_id_to_target_id = nnx.Param(jnp.zeros(
+             hf_config.draft_vocab_size, dtype=jnp.int32),
+                                                sharding=(None, ))
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         hidden_states: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         return self.model(
+             kv_caches,
+             input_ids,
+             hidden_states,
+             attention_metadata,
+         )
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         logits = self.lm_head(hidden_states)
+
+         target_vocab_size = self.vllm_config.model_config.get_vocab_size()
+         draft_vocab_size = self.vllm_config.speculative_config.draft_model_config.hf_config.draft_vocab_size
+
+         base = jnp.arange(draft_vocab_size, dtype=jnp.int32)
+         targets = base + self.draft_id_to_target_id.value
+
+         logits_new = jnp.full((logits.shape[0], target_vocab_size),
+                               -jnp.inf,
+                               dtype=logits.dtype)
+
+         logits_new = logits_new.at[:, targets].set(logits)
+
+         return logits_new
+
+     def combine_hidden_states(self, hidden_states: jax.Array) -> jax.Array:
+         return self.model.fc(hidden_states)
+
+     def load_weights(self, rng_key: jax.Array):
+         # Re-seed the draft model's RNG so it does not share RNG state with
+         # the target model.
+         self.rng = jax.random.key(self.vllm_config.model_config.seed)
+         spec_config = self.vllm_config.speculative_config
+         assert spec_config is not None
+
+         mappings = {
+             "midlayer.input_layernorm": "model.layers.0.input_layernorm.scale",
+             "midlayer.hidden_norm": "model.layers.0.hidden_norm.scale",
+             "midlayer.mlp.down_proj": "model.layers.0.mlp.down_proj.kernel",
+             "midlayer.mlp.gate_proj": "model.layers.0.mlp.gate_proj.kernel",
+             "midlayer.mlp.up_proj": "model.layers.0.mlp.up_proj.kernel",
+             "midlayer.post_attention_layernorm":
+             "model.layers.0.post_attention_layernorm.scale",
+             "midlayer.self_attn.k_proj":
+             "model.layers.0.self_attn.k_proj.kernel",
+             "midlayer.self_attn.o_proj":
+             "model.layers.0.self_attn.o_proj.kernel",
+             "midlayer.self_attn.q_proj":
+             "model.layers.0.self_attn.q_proj.kernel",
+             "midlayer.self_attn.v_proj":
+             "model.layers.0.self_attn.v_proj.kernel",
+             "norm": "model.norm.scale",
+             "fc": "model.fc.kernel",
+             "lm_head": "lm_head.kernel",
+             "d2t": "draft_id_to_target_id",
+         }
+
+         # Define keys to keep in original dtype (e.g., float32 for stability)
+         keep_original_dtype_keys_regex = [
+             r".*d2t.*",
+         ]
+
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+
+         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
+
+         load_hf_weights(
+             vllm_config=self.vllm_config,
+             model=self,
+             metadata_map=metadata_map,
+             mesh=self.mesh,
+             is_draft_model=True,
+             keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+
+         # If the embedding is not initialized, initialize it with a dummy
+         # array here to pass jit compilation. The real weights will be shared
+         # from the target model in the eagle3 class.
+         if isinstance(self.model.embed_tokens.embedding.value,
+                       jax.ShapeDtypeStruct):
+             self.model.embed_tokens.embedding.value = jnp.zeros(
+                 self.model.embed_tokens.embedding.shape,
+                 dtype=self.model.embed_tokens.embedding.dtype,
+             )
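
compute_logits above maps the draft head's reduced vocabulary back onto the full target vocabulary: d2t stores, for each draft id, the offset to its target id, and every other target position is filled with -inf so it can never be sampled. A minimal sketch of that scatter with made-up sizes and offsets (the real values come from the checkpoint's d2t tensor):

    import jax.numpy as jnp

    draft_vocab, target_vocab = 4, 10                # made-up sizes
    d2t = jnp.array([0, 2, 5, 6], dtype=jnp.int32)   # draft id i -> target id i + d2t[i]
    draft_logits = jnp.array([[0.1, 0.2, 0.3, 0.4]])

    targets = jnp.arange(draft_vocab, dtype=jnp.int32) + d2t   # [0, 3, 7, 9]
    full = jnp.full((draft_logits.shape[0], target_vocab), -jnp.inf,
                    dtype=draft_logits.dtype)
    full = full.at[:, targets].set(draft_logits)
    # Each row is now -inf everywhere except the four ids the draft can emit.
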
tpu_inference/models/jax/phi3.py
@@ -0,0 +1,376 @@
+ from typing import List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import Phi3Config, modeling_flax_utils
+ from vllm.config import VllmConfig
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import attention
+ from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                          load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+
+ class Phi3MLP(nnx.Module):
+
+     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs):
+         hidden_size = config.hidden_size
+         intermediate_size = config.intermediate_size
+         act = config.hidden_act
+
+         self.gate_up_proj = nnx.Linear(
+             hidden_size,
+             2 * intermediate_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=rng,
+         )
+         self.down_proj = nnx.Linear(
+             intermediate_size,
+             hidden_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.act_fn = modeling_flax_utils.ACT2FN[act]
+
+     def __call__(self, x: jax.Array) -> jax.Array:
+         gate_up = self.gate_up_proj(x)
+         gate, up = jnp.split(gate_up, 2, axis=-1)
+         fuse = up * self.act_fn(gate)
+         result = self.down_proj(fuse)
+         return result
+
+
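
Phi3MLP keeps the gate and up projections fused into a single matmul and splits the result before gating: the first half of gate_up_proj's output is the gate, the second half the up projection. A minimal sketch of the split-and-gate step, assuming SiLU is the configured activation:

    import jax
    import jax.numpy as jnp

    T, I = 2, 6                               # tokens, intermediate size (made up)
    gate_up = jnp.ones((T, 2 * I))            # stands in for gate_up_proj(x)
    gate, up = jnp.split(gate_up, 2, axis=-1)
    fuse = up * jax.nn.silu(gate)             # gated activation, fed to down_proj
    print(fuse.shape)                         # (2, 6)
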
+ class Phi3Attention(nnx.Module):
+
+     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.rope_theta = config.rope_theta
+         self.rope_scaling = config.rope_scaling
+         self.original_max_position_embeddings = config.original_max_position_embeddings
+         self.max_position_embeddings = config.max_position_embeddings
+
+         self.head_dim_original = getattr(config, "head_dim",
+                                          self.hidden_size // self.num_heads)
+         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+         sharding_size = mesh.shape["model"]
+         self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                     sharding_size)
+         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                        sharding_size)
+
+         self.mesh = mesh
+
+         self.qkv_proj = nnx.Einsum(
+             "TD,DNH->TNH",
+             (self.hidden_size, self.num_heads + self.num_kv_heads * 2,
+              self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.o_proj = nnx.Einsum(
+             "TNH,NHD->TD",
+             (self.num_heads, self.head_dim, self.hidden_size),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+             rngs=rng,
+         )
+
+         self._q_scale = 1.0
+         self._k_scale = 1.0
+         self._v_scale = 1.0
+         self.kv_cache_quantized_dtype = None
+         if kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 kv_cache_dtype)
+
+     def __call__(
+         self,
+         kv_cache: Optional[jax.Array],
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         md = attention_metadata
+         # qkv: (T, N + K * 2, H)
+         qkv = self.qkv_proj(x)
+         q, k, v = jnp.split(
+             qkv, [self.num_heads, self.num_heads + self.num_kv_heads], axis=1)
+         if self.rope_scaling:
+             q = apply_longrope(q, md.input_positions, self.head_dim_original,
+                                self.rope_scaling,
+                                self.original_max_position_embeddings,
+                                self.max_position_embeddings, self.rope_theta)
+             k = apply_longrope(k, md.input_positions, self.head_dim_original,
+                                self.rope_scaling,
+                                self.original_max_position_embeddings,
+                                self.max_position_embeddings, self.rope_theta)
+         else:
+             q = apply_rope(q, md.input_positions, self.head_dim_original,
+                            self.rope_theta, self.rope_scaling)
+             k = apply_rope(k, md.input_positions, self.head_dim_original,
+                            self.rope_theta, self.rope_scaling)
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                      k_scale, v_scale)
+         # outputs: (T, N, H)
+         new_kv_cache, outputs = attention(
+             kv_cache,
+             q,
+             k,
+             v,
+             attention_metadata,
+             self.mesh,
+             self.head_dim_original,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+         # o: (T, D)
+         o = self.o_proj(outputs)
+         return new_kv_cache, o
+
+
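
The fused QKV einsum emits num_heads + 2 * num_kv_heads head slices in one output, and jnp.split carves them apart at the indices [N, N + K]; with grouped-query attention K is typically smaller than N. A shape-only sketch with made-up head counts:

    import jax.numpy as jnp

    T, H = 3, 8                 # tokens, head dim (made up)
    N, K = 4, 2                 # query heads, kv heads
    qkv = jnp.zeros((T, N + 2 * K, H))
    q, k, v = jnp.split(qkv, [N, N + K], axis=1)
    print(q.shape, k.shape, v.shape)   # (3, 4, 8) (3, 2, 8) (3, 2, 8)
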
+ class Phi3DecoderLayer(nnx.Module):
+
+     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         rms_norm_eps = config.rms_norm_eps
+         hidden_size = config.hidden_size
+
+         self.input_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.self_attn = Phi3Attention(config=config,
+                                        dtype=dtype,
+                                        rng=rng,
+                                        mesh=mesh,
+                                        kv_cache_dtype=kv_cache_dtype)
+         self.post_attention_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.mlp = Phi3MLP(
+             config=config,
+             dtype=dtype,
+             rng=rng,
+         )
+
+     def __call__(
+         self,
+         kv_cache: jax.Array,
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         hidden_states = self.input_layernorm(x)
+         kv_cache, attn_output = self.self_attn(
+             kv_cache,
+             hidden_states,
+             attention_metadata,
+         )
+         attn_output += x
+
+         residual = attn_output
+         attn_output = self.post_attention_layernorm(attn_output)
+         outputs = self.mlp(attn_output)
+         outputs = residual + outputs
+         return kv_cache, outputs
+
+
+ class Phi3Model(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                  mesh: Mesh) -> None:
+         model_config = vllm_config.model_config
+         hf_config = model_config.hf_config
+         vocab_size = model_config.get_vocab_size()
+         dtype = model_config.dtype
+         rms_norm_eps = hf_config.rms_norm_eps
+         hidden_size = hf_config.hidden_size
+
+         self.embed = nnx.Embed(
+             num_embeddings=vocab_size,
+             features=hidden_size,
+             param_dtype=dtype,
+             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.layers = [
+             Phi3DecoderLayer(
+                 config=hf_config,
+                 dtype=dtype,
+                 rng=rng,
+                 mesh=mesh,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+             for _ in range(hf_config.num_hidden_layers)
+         ]
+         self.norm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         if model_config.hf_config.tie_word_embeddings:
+             self.lm_head = self.embed.embedding
+         else:
+             self.lm_head = nnx.Param(
+                 init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                 sharding=(None, "model"),
+             )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[List[jax.Array], jax.Array]:
+         x = self.embed(input_ids)
+         for i, layer in enumerate(self.layers):
+             kv_cache = kv_caches[i]
+             kv_cache, x = layer(
+                 kv_cache,
+                 x,
+                 attention_metadata,
+             )
+             kv_caches[i] = kv_cache
+         x = self.norm(x)
+         return kv_caches, x
+
+
+ class Phi3ForCausalLM(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh) -> None:
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+
+         self.model = Phi3Model(
+             vllm_config=vllm_config,
+             rng=self.rng,
+             mesh=mesh,
+         )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         *args,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         kv_caches, x = self.model(
+             kv_caches,
+             input_ids,
+             attention_metadata,
+         )
+         return kv_caches, x, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+         else:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value)
+         return logits
+
+     def get_metadata_map(self) -> MetadataMap:
+         sharding_size = self.mesh.shape["model"]
+
+         model_config = self.vllm_config.model_config
+         hf_config = model_config.hf_config
+
+         num_heads = hf_config.num_attention_heads
+         num_kv_heads = hf_config.num_key_value_heads
+         qkv_heads = num_heads + num_kv_heads * 2
+         hidden_size = model_config.get_hidden_size()
+
+         # head_dim before padding (kernels run on a padded head_dim for performance).
+         head_dim_original = model_config.get_head_size()
+
+         # Key: path to a HF layer weight
+         # Value: path to a nnx layer weight
+         name_map = {
+             "model.embed_tokens": "model.embed.embedding",
+             "model.layers.*.input_layernorm":
+             "model.layers.*.input_layernorm.scale",
+             "model.layers.*.mlp.down_proj":
+             "model.layers.*.mlp.down_proj.kernel",
+             "model.layers.*.mlp.gate_up_proj":
+             "model.layers.*.mlp.gate_up_proj.kernel",
+             "model.layers.*.post_attention_layernorm":
+             "model.layers.*.post_attention_layernorm.scale",
+             "model.layers.*.self_attn.qkv_proj":
+             "model.layers.*.self_attn.qkv_proj.kernel",
+             "model.layers.*.self_attn.o_proj":
+             "model.layers.*.self_attn.o_proj.kernel",
+             "model.norm": "model.norm.scale",
+         }
+         if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             name_map.update({
+                 "lm_head": "model.lm_head",
+             })
+
+         reshape_keys: dict[str, tuple[int, ...]] = {
+             "qkv_proj": (qkv_heads, head_dim_original, hidden_size),
+             "o_proj": (hidden_size, num_heads, head_dim_original),
+         }
+         transpose_keys: dict[str, tuple[int, ...]] = {
+             "lm_head": (1, 0),
+             "gate_up_proj": (1, 0),
+             "down_proj": (1, 0),
+             "qkv_proj": (2, 0, 1),
+             "o_proj": (1, 2, 0),
+         }
+
+         # key: (padding_dim, padding_size)
+         pad_keys: dict[str, tuple[int, ...]] = {
+             "qkv_proj": (1, sharding_size // num_heads),
+             "o_proj": (0, sharding_size // num_heads),
+         }
+
+         return MetadataMap(name_map=name_map,
+                            reshape_map=reshape_keys,
+                            bias_reshape_map={},
+                            transpose_map=transpose_keys,
+                            pad_map=pad_keys,
+                            bias_pad_map={})
+
+     def load_weights(self, rng_key: jax.Array):
+         # NOTE: Since we are using nnx.eval_shape to init the model,
+         # we have to pass real (dynamic) arrays here for __call__'s usage.
+         self.rng = nnx.Rngs(rng_key)
+
+         metadata_map = self.get_metadata_map()
+         load_hf_weights(vllm_config=self.vllm_config,
+                         model=self,
+                         metadata_map=metadata_map,
+                         mesh=self.mesh)
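
get_metadata_map above describes how each HuggingFace checkpoint tensor is adapted to the nnx layout: reshape_map splits the flat 2-D weight into per-head blocks, transpose_map reorders it into the einsum layout, and pad_map pads the head axis so it divides the "model" mesh axis evenly. A standalone sketch of the first two steps on a dummy qkv_proj weight, assuming the loader applies reshape before transpose (the exact order is internal to load_hf_weights):

    import jax.numpy as jnp

    num_heads, num_kv_heads, head_dim, hidden = 4, 2, 8, 32   # made-up config
    qkv_heads = num_heads + 2 * num_kv_heads

    # HF stores qkv_proj as a flat (qkv_heads * head_dim, hidden) matrix.
    w = jnp.zeros((qkv_heads * head_dim, hidden))
    w = w.reshape(qkv_heads, head_dim, hidden)    # reshape_map["qkv_proj"]
    w = jnp.transpose(w, (2, 0, 1))               # transpose_map["qkv_proj"]
    print(w.shape)   # (32, 8, 8): the (D, N, H) layout of "TD,DNH->TNH"
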