tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+ import os
  import re
  from dataclasses import dataclass
  from typing import List, Optional, Tuple
@@ -13,6 +14,7 @@ from torchax.ops.mappings import j2t_dtype
  from vllm.config import VllmConfig

  from tpu_inference import utils
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.attention.attention import AttentionMetadata
  from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
  from tpu_inference.layers.jax.constants import KVCacheType
@@ -69,6 +71,7 @@ class DeepSeekV3(nnx.Module):
  hidden_act: str = "silu"
  rms_norm_eps: float = 1e-06
  first_k_dense_replace: int = 3 # replace the first few MOE layers to dense layer.
+ self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla

  num_shared_experts = 1
  rope_theta = 10000
@@ -114,19 +117,30 @@ class DeepSeekV3(nnx.Module):
  qk_rope_head_dim=qk_rope_head_dim,
  v_head_dim=v_head_dim,
  num_local_experts=num_local_experts,
- model_dtype=dtype)
+ model_dtype=dtype,
+ use_mla_kernel=self.use_mla_kernel)

  self.embedder = Embedder(vocab_size=vocab_size,
  hidden_size=hidden_size,
  dtype=dtype,
  rngs=self.rng,
- vd_sharding=(('data', 'expert', 'model'),
+ vd_sharding=(ShardingAxisName.MLP_TENSOR,
  None),
  random_init=self.random_init)

  self.layers = []

  def _create_mla() -> MLA:
+ if self.use_mla_kernel:
+ query_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+ keyvalue_skh_spec = P(ShardingAxisName.MLP_TENSOR, None)
+ attn_o_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+
+ else:
+ query_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+ keyvalue_skh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+ attn_o_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+
  return MLA(
  rope_theta=rope_theta,
  rope_scaling=rope_scaling,
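The hunks above swap hard-coded mesh-axis strings for ShardingAxisName constants and pick different attention activation specs when the MLA kernel is enabled. As a rough illustration only (the mesh axis names "data" and "model" below are assumptions for this sketch, not values taken from ShardingAxisName), this is how a named-axis PartitionSpec such as query_tnh_spec shards a (tokens, heads, head_dim) activation across a JAX device mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2D device mesh with named axes ("data" and "model" are illustrative names).
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))

# Non-MLA branch above: split the head dimension of a (tokens, heads, head_dim)
# query activation across the "model" axis; the MLA-kernel branch shards tokens instead.
query_tnh_spec = P(None, "model", None)
q = jnp.zeros((8, 16, 128))
q_sharded = jax.device_put(q, NamedSharding(mesh, query_tnh_spec))
print(q_sharded.sharding.spec)  # PartitionSpec(None, 'model', None)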
@@ -137,10 +151,12 @@ class DeepSeekV3(nnx.Module):
  rms_norm_eps=rms_norm_eps,
  v_head_dim=v_head_dim,
  mesh=self.mesh,
+ use_mla_kernel=self.use_mla_kernel,
  random_init=self.random_init,
  hidden_size=hidden_size,
  num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
+ num_key_value_heads=1
+ if self.use_mla_kernel else num_key_value_heads,
  head_dim=v_head_dim, # MLA uses v_head_dim as head_dim
  dtype=dtype,
  # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
@@ -148,14 +164,14 @@ class DeepSeekV3(nnx.Module):
  rngs=self.rng,
  activation_attention_td=(None, None),
  activation_q_td=(None, None),
- query_tnh=P(None, 'model', None),
- keyvalue_skh=P(None, 'model', None),
+ query_tnh=query_tnh_spec,
+ keyvalue_skh=keyvalue_skh_spec,
  activation_attention_out_td=(None, None),
- attn_o_tnh=P(None, 'model', None),
- q_da_sharding=(None, 'model'),
- anh_sharding=(None, 'model', None),
- kv_da_sharding=(None, 'model'),
- nhd_sharding=('model', None, None))
+ attn_o_tnh=attn_o_tnh_spec,
+ q_da_sharding=(None, ShardingAxisName.VOCAB),
+ anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+ kv_da_sharding=(None, ShardingAxisName.VOCAB),
+ nhd_sharding=(ShardingAxisName.MLP_TENSOR, None, None))

  for i in range(first_k_dense_replace):
  block = TransformerBlock(
@@ -176,14 +192,15 @@ class DeepSeekV3(nnx.Module):
  rngs=self.rng,
  ),
  attn=_create_mla(),
- custom_module=DenseFFW(dtype=dtype,
- hidden_act=hidden_act,
- hidden_size=hidden_size,
- intermediate_size=ffw_intermediate_size,
- rngs=self.rng,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None),
- random_init=self.random_init))
+ custom_module=DenseFFW(
+ dtype=dtype,
+ hidden_act=hidden_act,
+ hidden_size=hidden_size,
+ intermediate_size=ffw_intermediate_size,
+ rngs=self.rng,
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+ random_init=self.random_init))

  self.layers.append(block)

@@ -200,9 +217,9 @@ class DeepSeekV3(nnx.Module):
  rngs=self.rng,
  routed_scaling_factor=2.5,
  dtype=dtype,
- activation_ffw_td=('data', None),
- ed_sharding=('model', None),
- e_sharding=('model', ))
+ activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+ ed_sharding=(ShardingAxisName.MLP_TENSOR, None),
+ e_sharding=(ShardingAxisName.MLP_TENSOR, ))
  if self.sparse_matmul:
  # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces
  custom_module = SparseMoE(
@@ -216,10 +233,10 @@ class DeepSeekV3(nnx.Module):
  hidden_act=hidden_act,
  rngs=self.rng,
  random_init=self.random_init,
- activation_ffw_td=('data', None),
- activation_ffw_ted=('data', None, None),
- edf_sharding=('model', None, None),
- efd_sharding=('model', None, None),
+ activation_ffw_td=(ShardingAxisName.MLP_TENSOR, None),
+ activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+ edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+ efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
  quantized_dtype=self.weight_loader.quant_dtype
  if self.weight_loader.is_model_quantized else None,
  router=router) if is_moe_layer else DenseFFW(
@@ -229,8 +246,8 @@ class DeepSeekV3(nnx.Module):
  intermediate_size=ffw_intermediate_size,
  rngs=self.rng,
  random_init=self.random_init,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None))
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
  else:
  custom_module = MoE(
  dtype=dtype,
@@ -241,10 +258,10 @@ class DeepSeekV3(nnx.Module):
  hidden_act=hidden_act,
  rngs=self.rng,
  random_init=self.random_init,
- activation_ffw_td=('data', None),
- activation_ffw_ted=('data', None, None),
- edf_sharding=('model', None, None),
- efd_sharding=('model', None, None),
+ activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+ activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+ edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+ efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
  router=router) if is_moe_layer else DenseFFW(
  dtype=dtype,
  hidden_act=hidden_act,
@@ -252,18 +269,18 @@ class DeepSeekV3(nnx.Module):
  intermediate_size=ffw_intermediate_size,
  rngs=self.rng,
  random_init=self.random_init,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None))
-
- shared_experts = DenseFFW(dtype=dtype,
- hidden_act=hidden_act,
- hidden_size=hidden_size,
- intermediate_size=num_shared_experts *
- moe_intermediate_size,
- rngs=self.rng,
- random_init=self.random_init,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None))
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
+
+ shared_experts = DenseFFW(
+ dtype=dtype,
+ hidden_act=hidden_act,
+ hidden_size=hidden_size,
+ intermediate_size=num_shared_experts * moe_intermediate_size,
+ rngs=self.rng,
+ random_init=self.random_init,
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None))

  pre_attention_norm = RMSNorm(
  dims=hidden_size,
@@ -304,10 +321,28 @@ class DeepSeekV3(nnx.Module):
  hidden_size=hidden_size,
  dtype=dtype,
  rngs=self.rng,
- vd_sharding=(('data', 'expert', 'model'), None),
- dv_sharding=(None, ('data', 'expert', 'model')),
+ vd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+ dv_sharding=(None, ShardingAxisName.MLP_TENSOR),
  random_init=self.random_init)

+ if os.environ.get("VLLM_LOGGING_LEVEL", "").upper() == "DEBUG":
+ self._print_model_architecture()
+
+ def _print_model_architecture(self):
+ num_display_layers = 5
+
+ logger.debug("### Embedding ###")
+ nnx.display(self.embedder)
+
+ logger.debug(f"\n### First {num_display_layers} Layers ###")
+ # Loop through the slice and display each layer
+ for i, layer in enumerate(self.layers[:num_display_layers]):
+ logger.debug(f"\n--- Layer {i} ---")
+ nnx.display(layer)
+
+ logger.debug("\n### LM Head ###")
+ nnx.display(self.lm_head)
+
  # For compatibility with flax.
  def apply(self, variables, *args, **kwargs):
  return self.__call__(*args, **kwargs)
@@ -352,10 +387,19 @@ class DeepSeekV3(nnx.Module):
  @dataclass
  class DeepSeekV3WeightLoader:

- def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
- q_lora_rank, kv_lora_rank, attn_heads, qk_nope_head_dim,
- qk_rope_head_dim, v_head_dim, num_local_experts, model_dtype):
-
+ def __init__(self,
+ vllm_config: VllmConfig,
+ num_layers,
+ hidden_size,
+ q_lora_rank,
+ kv_lora_rank,
+ attn_heads,
+ qk_nope_head_dim,
+ qk_rope_head_dim,
+ v_head_dim,
+ num_local_experts,
+ model_dtype,
+ use_mla_kernel=False):
  self.num_layers = num_layers
  self.names_and_weights_generator = model_weights_generator(
  model_name_or_path=vllm_config.model_config.model,
@@ -364,7 +408,12 @@ class DeepSeekV3WeightLoader:
  self.is_verbose = vllm_config.additional_config.get(
  "is_verbose", None) is not None
  self.num_routed_experts = num_local_experts
+ self.attn_heads = attn_heads
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.v_head_dim = v_head_dim
+ self.kv_lora_rank = kv_lora_rank
  self.model_dtype = model_dtype
+ self.use_mla_kernel = use_mla_kernel

  self._transpose_map = {
  # dense mlp
@@ -376,6 +425,8 @@ class DeepSeekV3WeightLoader:
  r"q_b_proj": (2, 0, 1),
  r"kv_a_proj_with_mqa": (1, 0),
  r"kv_b_proj": (2, 0, 1),
+ r"k_b_proj": (2, 0, 1), # used for MLA kernel
+ r"v_b_proj": (2, 0, 1), # used for MLA kernel
  r"o_proj": (1, 2, 0),
  # moe
  r"mlp\.gate\.weight": (1, 0),
@@ -393,6 +444,8 @@ class DeepSeekV3WeightLoader:
  (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
  "kv_b_proj":
  (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
+ "k_b_proj": (attn_heads, qk_nope_head_dim, kv_lora_rank),
+ "v_b_proj": (attn_heads, v_head_dim, kv_lora_rank),
  "o_proj": (hidden_size, attn_heads, v_head_dim)
  }

@@ -452,6 +505,13 @@ class DeepSeekV3WeightLoader:
  "model.layers.*.mlp.shared_experts.up_proj.weight":
  "layers.*.shared_experts.kernel_up_proj_DF",
  }
+ if self.use_mla_kernel:
+ self._loaded_to_standardized_keys.update({
+ "model.layers.*.self_attn.k_b_proj.weight":
+ "layers.*.attn.kernel_k_up_proj_ANH",
+ "model.layers.*.self_attn.v_b_proj.weight":
+ "layers.*.attn.kernel_v_up_proj_ANH",
+ })

  # TODO (jacobplatin): we shouldn't hard-code this, but the logic to obtain the true quantized dtype
  # is non-trivial and the default checkpoints all use this dtype
@@ -487,6 +547,15 @@ class DeepSeekV3WeightLoader:
  "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
  self.quantization_block_size_n,
  kv_lora_rank // self.quantization_block_size_n),
+ # used for MLA kernel
+ "k_b_proj":
+ (attn_heads,
+ qk_nope_head_dim // self.quantization_block_size_n,
+ kv_lora_rank // self.quantization_block_size_n),
+ # used for MLA kernel
+ "v_b_proj":
+ (attn_heads, v_head_dim // self.quantization_block_size_n,
+ kv_lora_rank // self.quantization_block_size_n),
  "o_proj":
  (hidden_size // self.quantization_block_size_n, attn_heads,
  v_head_dim // self.quantization_block_size_n),
@@ -802,21 +871,73 @@ class DeepSeekV3WeightLoader:
  f"Cumulative local memory: {cumulative_local_memory} GB"
  )
  else:
- weight_bytes, weight_shards = self._load_individual_weight(
- loaded_name,
- loaded_weight,
- model_params,
- model_for_loading.mesh,
- scale=scale)
- if self.is_verbose:
- cumulative_global_memory += weight_bytes
- cumulative_local_memory += weight_shards
- logger.info(
- f"Cumulative global memory: {cumulative_global_memory} GB"
- )
- logger.info(
- f"Cumulative local memory: {cumulative_local_memory} GB"
- )
+ if self.use_mla_kernel and "kv_b_proj" in loaded_name:
+ # loaded_weight shape: (num_heads * (d_k + d_v), kv_lora_rank)
+ # scale shape: (num_heads * (d_k + d_v) / block_n, kv_lora_rank / block_k)
+ # Reshape to (num_heads, (d_k + d_v), kv_lora_rank) and split
+ weight_reshaped = loaded_weight.view(
+ self.attn_heads,
+ self.qk_nope_head_dim + self.v_head_dim,
+ self.kv_lora_rank)
+ k_weight = weight_reshaped[:, :self.qk_nope_head_dim, :].reshape(
+ -1, self.kv_lora_rank)
+ v_weight = weight_reshaped[:, self.qk_nope_head_dim:, :].reshape(
+ -1, self.kv_lora_rank)
+
+ loaded_weights_list = [k_weight, v_weight]
+ loaded_names = [
+ loaded_name.replace("kv_b_proj", "k_b_proj"),
+ loaded_name.replace("kv_b_proj", "v_b_proj")
+ ]
+
+ scales_list = [None, None]
+ if scale is not None:
+ bn = self.quantization_block_size_n
+ bk = self.quantization_block_size_k
+ scale_reshaped = scale.view(
+ self.attn_heads,
+ (self.qk_nope_head_dim + self.v_head_dim) // bn,
+ self.kv_lora_rank // bk)
+
+ k_scale = scale_reshaped[:, :self.qk_nope_head_dim // bn, :].reshape(
+ -1, self.kv_lora_rank // bk)
+ v_scale = scale_reshaped[:, self.qk_nope_head_dim // bn:, :].reshape(
+ -1, self.kv_lora_rank // bk)
+ scales_list = [k_scale, v_scale]
+
+ else:
+ loaded_weights_list = [loaded_weight]
+ loaded_names = [loaded_name]
+ scales_list = [scale]
+
+ for loaded_name, loaded_weight, scale in zip(
+ loaded_names, loaded_weights_list, scales_list):
+
+ weight_bytes, weight_shards = self._load_individual_weight(
+ loaded_name,
+ loaded_weight,
+ model_params,
+ model_for_loading.mesh,
+ scale=scale)
+ if self.is_verbose:
+ cumulative_global_memory += weight_bytes
+ cumulative_local_memory += weight_shards
+ logger.info(
+ f"Cumulative global memory: {cumulative_global_memory} GB"
+ )
+ logger.info(
+ f"Cumulative local memory: {cumulative_local_memory} GB"
+ )

  del mlp_experts_gate_proj_weights
  del mlp_experts_up_proj_weights
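When use_mla_kernel is set, the loader above splits the checkpoint's combined kv_b_proj weight, and its blockwise quantization scales when present, into separate k_b_proj and v_b_proj tensors before loading each one individually. A minimal numpy sketch of the same reshape-and-split, using illustrative sizes rather than the model's real dimensions:

import numpy as np

# Illustrative sizes only; the real loader operates on torch tensors sized from the config.
attn_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# Combined checkpoint weight: (heads * (d_k + d_v), kv_lora_rank).
kv_b_proj = np.zeros(
    (attn_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank), np.float32)

# Reshape to (heads, d_k + d_v, kv_lora_rank), split along dim 1, then flatten back.
w = kv_b_proj.reshape(attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b_proj = w[:, :qk_nope_head_dim, :].reshape(-1, kv_lora_rank)
v_b_proj = w[:, qk_nope_head_dim:, :].reshape(-1, kv_lora_rank)

assert k_b_proj.shape == (attn_heads * qk_nope_head_dim, kv_lora_rank)
assert v_b_proj.shape == (attn_heads * v_head_dim, kv_lora_rank)

The scale tensors follow the same split, with each head dimension divided by the quantization block size.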
@@ -102,9 +102,9 @@ class GptOss(nnx.Module):
  rope_ntk_beta=rope_ntk_beta,
  rngs=self.rng,
  random_init=self.random_init,
- query_tnh=P(None, 'model', None),
- keyvalue_skh=P(None, 'model', None),
- attn_o_tnh=P(None, 'model', None),
+ query_tnh=P("data", 'model', None),
+ keyvalue_skh=P("data", 'model', None),
+ attn_o_tnh=P("data", 'model', None),
  dnh_sharding=P(None, 'model', None),
  dkh_sharding=P(None, 'model', None),
  nhd_sharding=P('model', None, None),
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })

- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,
@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):

  def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
  metadata_map: MetadataMap):
- model_config = vllm_config.model_config
+ model_config = vllm_config.speculative_config.draft_model_config
  hf_config = model_config.hf_config

  num_heads = hf_config.num_attention_heads
  num_kv_heads = hf_config.num_key_value_heads
- hidden_size = model_config.get_hidden_size()
-
+ hidden_size = hf_config.hidden_size
  head_dim_original = model_config.get_head_size()

  metadata_map.reshape_map.update({
@@ -305,6 +304,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
  "fc": "model.fc.kernel",
  "lm_head": "lm_head.kernel",
  "d2t": "draft_id_to_target_id",
+ "embed_tokens":
+ "model.embed_tokens.embedding", # Some checkpoints need this
  }

  # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -312,7 +313,9 @@ class EagleLlama3ForCausalLM(nnx.Module):
  r".*d2t.*",
  ]

- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(
+ self.vllm_config.speculative_config.draft_model_config, self.mesh,
+ mappings)

  update_reshape_map_for_eagle3(self.vllm_config, metadata_map)

@@ -324,7 +327,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
  is_draft_model=True,
  keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)

- # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
+ # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
  if isinstance(self.model.embed_tokens.embedding.value,
  jax.ShapeDtypeStruct):
  self.model.embed_tokens.embedding.value = jnp.zeros(