tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/common/model_loader.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import Any, Optional
 
@@ -15,9 +29,10 @@ from vllm.utils.func_utils import supports_kw
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
-    load_random_weights_into_qwix_abstract_model
+    load_random_weights_into_qwix_abstract_model,
+    update_vllm_config_for_qwix_quantization)
 from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
@@ -218,6 +233,10 @@ def get_flax_model(
     model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
     vllm_config.model_config.dtype = model_dtype
 
+    # Only perform qwix quantization if it is jax model.
+    if vllm_config.model_config:
+        update_vllm_config_for_qwix_quantization(vllm_config)
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
@@ -269,10 +288,9 @@ def get_flax_model(
 
     # Multi-modal support only
    # This function calculates the image token's embeddings by VIT
-    def
-            **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -280,9 +298,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -298,10 +316,8 @@ def get_flax_model(
                                          None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-
-
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -312,8 +328,8 @@ def get_flax_model(
 
     multimodal_fns = {
         "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "
-        "
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
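The wrappers above follow the flax.nnx functional pattern: the module is split into a static graphdef plus a state pytree, the graphdef is bound with functools.partial, and each jitted function re-merges them before calling the method. A minimal sketch of that pattern, using a hypothetical Toy module rather than the real tpu-inference model classes:

    import functools
    import jax
    from flax import nnx

    class Toy(nnx.Module):  # stand-in for the real model classes
        def __init__(self, rngs: nnx.Rngs):
            self.linear = nnx.Linear(4, 4, rngs=rngs)

        def embed_input_ids(self, x):
            return self.linear(x)

    # Split the module into static structure and a pytree of parameters.
    graphdef, state = nnx.split(Toy(nnx.Rngs(0)))

    @jax.jit
    def run_embed_input_ids(graphdef, state, x):
        model = nnx.merge(graphdef, state)  # rebuild the module inside the traced function
        return model.embed_input_ids(x)

    # Bind the static graphdef, leaving (state, x) as the jitted arguments.
    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)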
@@ -365,6 +381,11 @@ def get_model(
 
     match impl:
         case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
             try:
                 # Try to load the flax model first
                 return get_flax_model(vllm_config, rng, mesh, is_draft_model)
@@ -466,14 +487,14 @@ def register_model(arch: str, model: Any) -> None:
     )
 
     # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def
+    def unimplemented_embed_input_ids(
         self,
         input_ids: "torch.Tensor",
         positions: "torch.Tensor",
         inputs_embeds: Optional["torch.Tensor"] = None,
     ) -> "torch.Tensor":
         raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
         )
 
     # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -489,7 +510,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })

(new file, license header only)

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/models/jax/deepseek_v3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import re
 from dataclasses import dataclass
@@ -14,6 +28,7 @@ from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.quantization import u8_unpack_e2m1
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.attention.attention import AttentionMetadata
 from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
@@ -25,10 +40,8 @@ from tpu_inference.layers.jax.moe.moe import MoE
 from tpu_inference.layers.jax.transformer_block import (
     SharedExpertsTransformerBlock, TransformerBlock)
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-    get_quant_dtype_from_qwix_config
 from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info
+    get_param, model_weights_generator, print_param_info)
 
 logger = init_logger(__name__)
 
@@ -73,6 +86,8 @@ class DeepSeekV3(nnx.Module):
         first_k_dense_replace: int = 3  # replace the first few MOE layers to dense layer.
         self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla
 
+        logger.info(f"Is using MLA kernel in DeepSeek: {self.use_mla_kernel}")
+
         num_shared_experts = 1
         rope_theta = 10000
         rope_scaling = {
@@ -169,9 +184,10 @@
             activation_attention_out_td=(None, None),
             attn_o_tnh=attn_o_tnh_spec,
             q_da_sharding=(None, ShardingAxisName.VOCAB),
+            ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
             anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
             kv_da_sharding=(None, ShardingAxisName.VOCAB),
-
+            rd_sharding=(ShardingAxisName.MLP_TENSOR, None))
 
         for i in range(first_k_dense_replace):
             block = TransformerBlock(
@@ -422,12 +438,12 @@ class DeepSeekV3WeightLoader:
            r"mlp\.up_proj": (1, 0),
            # mla
            r"q_a_proj": (1, 0),
-           r"q_b_proj": (
+           r"q_b_proj": (1, 0),
            r"kv_a_proj_with_mqa": (1, 0),
-           r"kv_b_proj": (
+           r"kv_b_proj": (1, 0),
            r"k_b_proj": (2, 0, 1),  # used for MLA kernel
            r"v_b_proj": (2, 0, 1),  # used for MLA kernel
-           r"o_proj": (1,
+           r"o_proj": (1, 0),
            # moe
            r"mlp\.gate\.weight": (1, 0),
            r"mlp\.experts\.\d+\.gate_proj": (0, 2, 1),
@@ -439,15 +455,6 @@
            # lm_head
            r"lm_head\.weight": (1, 0)
        }
-        self._weight_shape_map = {
-            "q_b_proj":
-            (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
-            "kv_b_proj":
-            (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
-            "k_b_proj": (attn_heads, qk_nope_head_dim, kv_lora_rank),
-            "v_b_proj": (attn_heads, v_head_dim, kv_lora_rank),
-            "o_proj": (hidden_size, attn_heads, v_head_dim)
-        }
 
        # Set the mappings from loaded parameter keys to standardized names.
        self._loaded_to_standardized_keys = {
@@ -472,13 +479,13 @@
            "model.layers.*.self_attn.q_a_proj.weight":
            "layers.*.attn.kernel_q_down_proj_DA",
            "model.layers.*.self_attn.q_b_proj.weight":
-           "layers.*.attn.
+           "layers.*.attn.kernel_q_up_proj_AP",
            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight":
            "layers.*.attn.kernel_kv_down_proj_DA",
            "model.layers.*.self_attn.kv_b_proj.weight":
-           "layers.*.attn.
+           "layers.*.attn.kernel_kv_up_proj_AL",
            "model.layers.*.self_attn.o_proj.weight":
-           "layers.*.attn.
+           "layers.*.attn.kernel_o_proj_RD",
            # Dense ffw
            "model.layers.*.mlp.gate_proj.weight":
            "layers.*.custom_module.kernel_gating_DF",
@@ -512,66 +519,43 @@
            "model.layers.*.self_attn.v_b_proj.weight":
            "layers.*.attn.kernel_v_up_proj_ANH",
        })
-
-
-        # is non-trivial and the default checkpoints all use this dtype
-        self.quant_dtype = jnp.float8_e4m3fn
+        # TODO (jacobplatin): we should not be hard-coding these
+        self.scale_dtype, self.quant_dtype = jnp.bfloat16, jnp.float8_e4m3fn
 
        self.is_model_quantized = not vllm_config.additional_config.get(
            "skip_quantization", False)
-        if self.is_model_quantized:
-            # TODO (jacobplatin): expand support eventually
-            quantization_type = vllm_config.model_config.hf_config.quantization_config[
-                "quant_method"]
-            assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
-            self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
-                vllm_config)
-
-            logger.info(
-                f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
-            )
 
-
-                "weight_block_size"]
-            assert len(
-                quantization_block_sizes
-            ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-            self.quantization_block_size_n = quantization_block_sizes[0]
-            self.quantization_block_size_k = quantization_block_sizes[1]
-            # TODO (jacobplatin): remove this check in the future
-            assert self.quantization_block_size_n == self.quantization_block_size_k, "Quantization block size n and k must be the same!"
-            # NOTE: this is only needed for pre-quantized models
-            self._scale_shape_map = {
-                "q_b_proj": (1, qk_nope_head_dim + qk_rope_head_dim,
-                             q_lora_rank // self.quantization_block_size_n),
-                "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
-                              self.quantization_block_size_n,
-                              kv_lora_rank // self.quantization_block_size_n),
-                # used for MLA kernel
-                "k_b_proj":
-                (attn_heads,
-                 qk_nope_head_dim // self.quantization_block_size_n,
-                 kv_lora_rank // self.quantization_block_size_n),
-                # used for MLA kernel
-                "v_b_proj":
-                (attn_heads, v_head_dim // self.quantization_block_size_n,
-                 kv_lora_rank // self.quantization_block_size_n),
-                "o_proj":
-                (hidden_size // self.quantization_block_size_n, attn_heads,
-                 v_head_dim // self.quantization_block_size_n),
-            }
+        if self.is_model_quantized:
            # NOTE: this is only needed for pre-quantized models when doing random weight loading
+            # because the scales that Qwix configures by default don't necessarily match the
+            # scales in practice
            # TODO (jacobplatin): remove or clean this up
-            self.
-
-                "
-                "
-                "
-
-                "
-                "
+            self.scale_shape_map_for_random_weight_loading = {
+                # MoE experts (3D)
+                "custom_module.kernel_down_proj_EFD": (256, 8, 7168),
+                "custom_module.kernel_gating_EDF": (256, 28, 2048),
+                "custom_module.kernel_up_proj_EDF": (256, 28, 2048),
+                # Shared experts (2D)
+                "shared_experts.kernel_down_proj_FD": (8, 7168),
+                "shared_experts.kernel_gating_DF": (28, 2048),
+                "shared_experts.kernel_up_proj_DF": (28, 2048),
+                # Dense FFW (2D)
+                "custom_module.kernel_gating_DF": (28, 18432),
+                "custom_module.kernel_up_proj_DF": (28, 18432),
+                "custom_module.kernel_down_proj_FD": (72, 7168),
+                # Attention (3D for MLA, 2D for the rest)
+                "attn.kernel_q_down_proj_DA": (28, 1536),
+                "attn.kernel_q_up_proj_AP": (6, 24576),
+                "attn.kernel_kv_down_proj_DA": (28, 576),
+                "attn.kernel_kv_up_proj_AL": (2, 32768),
+                "attn.kernel_o_proj_RD": (64, 7168),
+                "attn.kernel_k_up_proj_ANH": (2, 128, 128),  # MLA
+                "attn.kernel_v_up_proj_ANH": (2, 128, 128),  # MLA
            }
 
+            # TODO (jacobplatin): remove this check eventually!
+            assert self.quant_dtype == jnp.float8_e4m3fn, f"Expected quant_dtype to be float8_e4m3fn for DeepSeek but got {self.quant_dtype}"
+
    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
@@ -649,45 +633,56 @@
                base_model_weight, "array") else base_model_weight.sharding
 
        # Convert weights from torch into numpy
-
-
-
-
-
-        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-        # raw data as integers before converting to a JAX array.
-        weight_np = jnp.array(
-            weight.view(torch_view_type).numpy()).view(cast_type)
+        if weight.dtype == torch.uint8 and scale is not None:
+            # Assume packed FP4 format when uint8 weights with scale provided
+            weight_jax_u8 = jnp.array(weight.cpu().numpy())
+            weight_np = u8_unpack_e2m1(weight_jax_u8)
+            scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
        else:
-
-
+            cast_type = model_weight.value.dtype
+            # Special-case: FP4 values stored as FP8 for compatibility.
+            # If the model expects float4_e2m1fn but the checkpoint provides FP8,
+            # convert by numeric value (float32) then cast to float4.
+            if cast_type == jnp.float4_e2m1fn and weight.dtype == torch.float8_e4m3fn:
+                weight_np = jnp.array(weight.float().numpy()).astype(cast_type)
+            else:
+                torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
 
-
-
+                if torch_view_type:
+                    # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+                    # raw data as integers before converting to a JAX array.
+                    weight_np = jnp.array(
+                        weight.view(torch_view_type).numpy()).view(cast_type)
+                else:
+                    raise ValueError(
+                        f"Unsupported dtype for tensor conversion: {cast_type}"
+                    )
 
-
-
-
-            scale = reshape_params(name, scale, self._scale_shape_map)
+            if scale is not None:
+                scale = scale.to(torch.float32).numpy().astype(
+                    self.scale_dtype)
        weight_np = self._transpose_params(name, weight_np)
        if scale is not None:
            scale = self._transpose_params(name, scale)
+            # Ensure scale is broadcastable to weight_np by repeating per-axis.
            weight_shape = weight_np.shape
            scale_shape = scale.shape
-
-
-
-
-
-
-
+            if len(weight_shape) == len(scale_shape):
+                new_scale = scale
+                for wdim, sdim in zip(weight_shape, scale_shape):
+                    if (wdim % sdim != 0):
+                        raise ValueError(
+                            f"Weight dim {wdim} is not divisible by scale dim {sdim} for weight {name} with shape {weight_shape} and scale {scale_shape}!"
+                        )
+                if scale_shape != new_scale.shape:
                    logger.warning(
-                        f"
-                        f"where the scale_dim {scale_dim} does not match the weight_dim {weight_dim} "
-                        f"multiplied by the quantization block size {self.quantization_block_size_n}. "
-                        f"Repeating the scale to new shape {scale.shape} along axis {idx} with repeat size {self.quantization_block_size_n}."
+                        f"Adjusted scale shape {scale_shape} to {new_scale.shape} to match weight {weight_shape}"
                    )
-
+                scale = new_scale
+            else:
+                raise ValueError(
+                    f"Scale rank {scale_shape} does not match weight rank {weight_shape}"
+                )
 
        if model_weight.value.shape != weight_np.shape:
            raise ValueError(
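For orientation: in block-wise fp8 checkpoints of this kind, each scale entry covers one block of the weight, which is why the loader above insists that every weight dimension be an integer multiple of the matching scale dimension. A minimal sketch of the corresponding dequantization, assuming a square block size of 128 (the block size, names, and layout here are assumptions, not the loader's actual code):

    import jax.numpy as jnp

    def dequantize_blockwise_fp8_sketch(weight_fp8, scale, block=128):
        # weight_fp8: (N, K) float8_e4m3fn values; scale: (N // block, K // block) per-block scales.
        # Expand each per-block scale to cover its block, then multiply elementwise.
        scale_full = jnp.repeat(jnp.repeat(scale, block, axis=0), block, axis=1)
        return weight_fp8.astype(jnp.float32) * scale_full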
@@ -721,10 +716,8 @@
                logger.warning(
                    f"Could not create sharded scale for {name} with shape {scale.shape} and sharding {sharding}, skipping sharding..."
                )
-
-
-            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, "Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
-            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, "Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
+            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, f"Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
+            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
            base_model_weight.array.scale.value = maybe_sharded_scale
            base_model_weight.array.qvalue.value = sharded_array
        else:
@@ -790,7 +783,11 @@
            # TODO (jacobplatin): refactor this so that we instead change / update `model_weights_generator`
            # instead of checking "weight_scale_inv" and assuming quantization method is fp8
            scale = None
-
+            # Mixed quantization: accept both fp8 and packed fp4 (uint8) tensors
+            allowed_quant_dtypes = {
+                j2t_dtype(self.quant_dtype.dtype), torch.uint8
+            }
+            if loaded_weight.dtype in allowed_quant_dtypes:
                if self.is_model_quantized:
                    scale_name = loaded_name.replace(
                        ".weight", ".weight_scale_inv")
@@ -880,11 +877,9 @@
                        self.qk_nope_head_dim + self.v_head_dim,
                        self.kv_lora_rank)
                    k_weight = weight_reshaped[:, :self.
-                                               qk_nope_head_dim, :]
-
-
-                        qk_nope_head_dim:, :].reshape(
-                            -1, self.kv_lora_rank)
+                                               qk_nope_head_dim, :]
+                    v_weight = weight_reshaped[:,
+                                               self.qk_nope_head_dim:, :]
 
                    loaded_weights_list = [k_weight, v_weight]
                    loaded_names = [
@@ -894,25 +889,19 @@
 
                    scales_list = [None, None]
                    if scale is not None:
-
-
+                        assert loaded_weight.shape[0] == scale.shape[0]
+                        block_size_k = loaded_weight.shape[
+                            1] // scale.shape[1]
+                        assert block_size_k > 0, f"Expected non-zero block size but got {block_size_k}!"
                        scale_reshaped = scale.view(
                            self.attn_heads,
-                            (self.qk_nope_head_dim + self.v_head_dim)
-
+                            (self.qk_nope_head_dim + self.v_head_dim),
+                            self.kv_lora_rank // block_size_k)
 
                        k_scale = scale_reshaped[:, :self.
-                                                 qk_nope_head_dim
-                                                 bn, :].reshape(
-                                                     -1,
-                                                     self.kv_lora_rank //
-                                                     bk)
+                                                 qk_nope_head_dim, :]
                        v_scale = scale_reshaped[:,
-                                                 self.qk_nope_head_dim
-                                                 bn:, :].reshape(
-                                                     -1,
-                                                     self.kv_lora_rank //
-                                                     bk)
+                                                 self.qk_nope_head_dim:, :]
                        scales_list = [k_scale, v_scale]
 
                    else:
tpu_inference/models/jax/gpt_oss.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -11,6 +25,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
 
+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -18,8 +35,6 @@ from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 
@@ -205,7 +220,7 @@ class GptOss(nnx.Module):
 
        # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
        swap_mlp_transform = transforms[
-            "swap_last2"] if quant_method ==
+            "swap_last2"] if quant_method == MXFP4 else None
 
        mappings = {
            # Embeddings, Norms, and LM Head
@@ -285,7 +300,7 @@
        # Build a pool of weights with MXFP4 experts combined if neededs
        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
            names_and_weights_generator,
-            mappings) if quant_method ==
+            mappings) if quant_method == MXFP4 else {
                loaded_name: loaded_weight
                for loaded_name, loaded_weight in names_and_weights_generator
            })
@@ -316,8 +331,9 @@
            blocks_u8, scales_u8 = loaded_weight
            # Quantized param (QArray): set qvalue/scale directly and skip regular path
            if hasattr(model_weight, "array"):  # QArray check
-                codes_fp32_t
-
+                codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                    jnp.float32)
+                scales_fp32_t = e8m0_to_fp32(scales_u8)
                self._load_mxfp4(
                    model_weight=model_weight,
                    codes_fp32_t=codes_fp32_t,
@@ -328,7 +344,7 @@
                print_param_info(model_weight, loaded_name)
                continue
            # Not a QArray: dequantize MXFP4 to BF16 full weights
-            prepared_weight =
+            prepared_weight = dequantize_tensor_from_mxfp4_packed(
                blocks_u8, scales_u8)
 
            # Single regular-tensor load call (BF16 or dequantized MXFP4)
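For context, MXFP4 packs two 4-bit e2m1 values per byte and attaches one shared e8m0 (power-of-two) scale per 32-element block, so the u8_unpack_e2m1 / e8m0_to_fp32 calls above amount to nibble decoding plus an exponent lookup. A rough sketch of that decoding under those assumptions (nibble order and the helpers' exact semantics are guesses, not the package's actual implementation):

    import jax.numpy as jnp

    # The 8 non-negative e2m1 magnitudes (1 sign, 2 exponent, 1 mantissa bits).
    _E2M1_MAGNITUDES = jnp.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], jnp.float32)

    def unpack_e2m1_sketch(packed_u8):
        # Decode two 4-bit e2m1 codes from every uint8 (low nibble first, assumed).
        codes = jnp.stack([packed_u8 & 0x0F, packed_u8 >> 4], axis=-1)
        codes = codes.reshape(*packed_u8.shape[:-1], -1)
        sign = jnp.where((codes & 0x8) != 0, -1.0, 1.0)
        return sign * _E2M1_MAGNITUDES[codes & 0x7]

    def e8m0_to_fp32_sketch(scale_u8):
        # e8m0 is a biased power-of-two exponent: value = 2**(e - 127).
        return jnp.exp2(scale_u8.astype(jnp.float32) - 127.0)

Dequantizing a block then comes down to multiplying its 32 decoded codes by the block's decoded scale, which is what the non-QArray path above does before loading BF16 weights.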
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Union
 