tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic by the registry.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/qwen2_5_vl.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from functools import partial
 from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
@@ -486,6 +500,11 @@ class Qwen2_5_VisionTransformer(nnx.Module):
                                  dtype=dtype,
                                  rngs=rngs)
 
+        additional_config = getattr(vllm_config, "additional_config",
+                                    None) or {}
+        self.enable_dynamic_image_sizes = additional_config.get(
+            "enable_dynamic_image_sizes", False)
+
     def rotary_pos_emb_thw(self, t, h, w):
         hpos_ids, wpos_ids = jnp.indices((h, w))
         hpos_ids = hpos_ids.reshape(
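Note on the new `enable_dynamic_image_sizes` flag read above: it is pulled from vLLM's `additional_config`, so it would be switched on at engine construction. A minimal sketch, assuming `additional_config` is accepted as an engine argument and with an illustrative model name:

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-VL-7B-Instruct",  # illustrative model
        additional_config={"enable_dynamic_image_sizes": True},
    )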
@@ -579,21 +598,7 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         return max_seqlen, seqlens
 
-    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
-                                                           int]]) -> jax.Array:
-        # x: pixel_values: jax.Array
-        # """Shape:
-        # `(num_patches, num_channels * patch_size * patch_size)`
-        # """
-
-        # grid_thw: image_grid_thw: jax.Array
-        # """Shape: `(num_images, 3)`
-        # This should be in `(grid_t, grid_h, grid_w)` format.
-        # """
-        hidden_states = self.patch_embed(x)
-
-        # num of patches
-        seq_len = x.shape[0]
+    def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
         # num of images/videoes
         num_grids = len(grid_thw)
 
@@ -638,6 +643,42 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                              mode='constant',
                              constant_values=0)
+        return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
+
+    def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
+                   cu_window_seqlens):
+        # padding
+        num_patches = int(rotary_pos_emb.shape[0])
+        bucket_num_patches = 1 << (num_patches - 1).bit_length()
+        num_tokens = window_index.shape[0]
+        bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+        max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
+
+        rotary_pos_emb = jnp.pad(rotary_pos_emb,
+                                 ((0, bucket_num_patches - num_patches),
+                                  (0, 0)))
+        window_index = jnp.concatenate([
+            window_index,
+            jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
+        ])
+        cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
+        pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
+        cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
+        cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
+
+        x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
+
+        return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
+
+    def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
+                              rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
+                              cu_window_seqlens: jax.Array) -> jax.Array:
+        hidden_states = self.patch_embed(x)
+
+        # num of patches
+        seq_len = x.shape[0]
 
         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
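The `pad_inputs` method added above buckets `num_patches` up to the next power of two via `1 << (num_patches - 1).bit_length()`, so the jitted encoder only ever sees a bounded set of shapes. A standalone sketch of the bucketing math:

    def bucket_size(num_patches: int) -> int:
        # Next power of two >= num_patches (e.g. 1000 -> 1024, 1025 -> 2048).
        return 1 << (num_patches - 1).bit_length()

    assert bucket_size(1000) == 1024
    assert bucket_size(1024) == 1024
    assert bucket_size(1025) == 2048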
@@ -664,6 +705,48 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         hidden_states = hidden_states[reverse_indices, :]
         return hidden_states
 
+    @jax.jit
+    def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
+                          cu_seqlens, cu_window_seqlens):
+        return self.compute_hidden_states(x_padded, window_index,
+                                          rotary_pos_emb, cu_seqlens,
+                                          cu_window_seqlens)
+
+    @partial(
+        jax.jit,
+        static_argnames=("grid_thw", ),
+    )
+    def encode_jit(self, x, grid_thw):
+        window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+            grid_thw)
+        return self.compute_hidden_states(x, window_index, rotary_pos_emb,
+                                          cu_seqlens, cu_window_seqlens)
+
+    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                           int]]) -> jax.Array:
+        # x: pixel_values: jax.Array
+        # """Shape:
+        # `(num_patches, num_channels * patch_size * patch_size)`
+        # """
+
+        # grid_thw: image_grid_thw: jax.Array
+        # """Shape: `(num_images, 3)`
+        # This should be in `(grid_t, grid_h, grid_w)` format.
+        # """
+        if self.enable_dynamic_image_sizes:
+            window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+                grid_thw)
+            x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
+                x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
+
+            hidden_states = self.encode_padded_jit(x_padded, window_index,
+                                                   rotary_pos_emb, cu_seqlens,
+                                                   cu_window_seqlens)
+            return hidden_states[:num_tokens]
+
+        else:
+            return self.encode_jit(x, grid_thw)
+
 
 class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
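The two jitted entry points above trade off differently: `encode_jit` marks `grid_thw` static, so every distinct image grid triggers a fresh trace and compile, while `encode_padded_jit` takes only array arguments whose shapes are already bucketed. A minimal sketch of the static-argument behavior (standard JAX, not code from this package):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnames=("grid_thw", ))
    def f(x, grid_thw):
        return x * len(grid_thw)

    f(jnp.ones(4), ((1, 2, 2), ))   # compiles
    f(jnp.ones(4), ((1, 4, 4), ))   # new static value -> recompiles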
@@ -888,10 +971,6 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         # "video"] = self._parse_and_validate_video_input(**kwargs)
         return mm_input_by_modality
 
-    @partial(
-        jax.jit,
-        static_argnames=("image_grid_thw", ),
-    )
     def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
         return self.visual(image_pixel_values, (image_grid_thw, ))
 
@@ -931,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def
-
-
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)
@@ -957,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def
+    def embed_input_ids(
             self, input_ids: jax.Array,
             multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
|
|
|
1072
1151
|
self,
|
|
1073
1152
|
run_compilation_fn: Callable,
|
|
1074
1153
|
) -> None:
|
|
1075
|
-
image_shapes = []
|
|
1076
|
-
if (warmup_config := self.vllm_config.additional_config.get(
|
|
1077
|
-
"vision_warmup_config")):
|
|
1078
|
-
image_shapes = warmup_config.get("image_shapes")
|
|
1079
|
-
|
|
1080
1154
|
vc = self.vllm_config.model_config.hf_config.vision_config
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1155
|
+
patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
|
|
1156
|
+
if self.visual.enable_dynamic_image_sizes:
|
|
1157
|
+
spatial_merge_unit = vc.spatial_merge_size**2
|
|
1158
|
+
max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
|
|
1159
|
+
mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
|
|
1160
|
+
limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
|
|
1161
|
+
|
|
1162
|
+
max_patches = int(
|
|
1163
|
+
min(max_num_batched_tokens * spatial_merge_unit,
|
|
1164
|
+
limit_pixels / (vc.patch_size**2)))
|
|
1165
|
+
|
|
1166
|
+
num_patches_paddings = [
|
|
1167
|
+
1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
|
|
1168
|
+
]
|
|
1169
|
+
rotary_dim = vc.hidden_size // vc.num_heads // 2
|
|
1170
|
+
vit_merger_window_size = (vc.window_size //
|
|
1171
|
+
vc.spatial_merge_size // vc.patch_size)
|
|
1172
|
+
|
|
1173
|
+
for num_patches in num_patches_paddings:
|
|
1174
|
+
dummy_x_padded = jnp.ones(
|
|
1175
|
+
(num_patches, patch_input_dim),
|
|
1176
|
+
dtype=self.vllm_config.model_config.dtype)
|
|
1177
|
+
|
|
1178
|
+
num_tokens = num_patches // spatial_merge_unit
|
|
1179
|
+
dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
|
|
1180
|
+
|
|
1181
|
+
dummy_rotary_pos_emb = jnp.ones(
|
|
1182
|
+
(num_patches, rotary_dim),
|
|
1183
|
+
dtype=self.vllm_config.model_config.dtype)
|
|
1184
|
+
|
|
1185
|
+
dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
|
|
1186
|
+
dtype=jnp.int32)
|
|
1187
|
+
|
|
1188
|
+
max_windows = (num_tokens // vit_merger_window_size) + 2
|
|
1189
|
+
patches_per_window = (vit_merger_window_size**
|
|
1190
|
+
2) * spatial_merge_unit
|
|
1191
|
+
dummy_cu_window_seqlens = jnp.arange(
|
|
1192
|
+
max_windows + 1, dtype=jnp.int32) * patches_per_window
|
|
1193
|
+
dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
|
|
1194
|
+
num_patches)
|
|
1195
|
+
|
|
1196
|
+
run_compilation_fn("vision_encoder_padded",
|
|
1197
|
+
self.visual.encode_padded_jit,
|
|
1198
|
+
dummy_x_padded,
|
|
1199
|
+
dummy_window_index,
|
|
1200
|
+
dummy_rotary_pos_emb,
|
|
1201
|
+
dummy_cu_seqlens,
|
|
1202
|
+
dummy_cu_window_seqlens,
|
|
1203
|
+
num_patches=num_patches)
|
|
1204
|
+
else:
|
|
1205
|
+
image_shapes = []
|
|
1206
|
+
if (warmup_config := self.vllm_config.additional_config.get(
|
|
1207
|
+
"vision_warmup_config")):
|
|
1208
|
+
image_shapes = warmup_config.get("image_shapes")
|
|
1209
|
+
|
|
1210
|
+
factor = vc.patch_size * vc.spatial_merge_size
|
|
1211
|
+
for input_hw in image_shapes:
|
|
1212
|
+
if not isinstance(input_hw, list) or len(input_hw) != 2:
|
|
1213
|
+
logger.warning(f"Skipping invalid shape {input_hw}.")
|
|
1214
|
+
continue
|
|
1215
|
+
h_input, w_input = input_hw
|
|
1216
|
+
h_processed = round(h_input / factor) * factor
|
|
1217
|
+
w_processed = round(w_input / factor) * factor
|
|
1218
|
+
t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
|
|
1219
|
+
grid_thw = (t, h, w)
|
|
1220
|
+
num_patches = t * h * w
|
|
1221
|
+
|
|
1222
|
+
dummy_pixel_values = jnp.ones(
|
|
1223
|
+
(num_patches, patch_input_dim),
|
|
1224
|
+
self.vllm_config.model_config.dtype,
|
|
1225
|
+
)
|
|
1226
|
+
dummy_grid_thw = (grid_thw, )
|
|
1099
1227
|
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1228
|
+
run_compilation_fn("vision_encoder",
|
|
1229
|
+
self.visual.encode_jit,
|
|
1230
|
+
dummy_pixel_values,
|
|
1231
|
+
dummy_grid_thw,
|
|
1232
|
+
image_shape=input_hw)
|
|
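The warmup above precompiles one program per power-of-two bucket from 16 up to the patch budget, so runtime images never hit an unseen shape. A worked example of the enumeration, with a hypothetical `max_patches`:

    max_patches = 5000  # hypothetical bound from the scheduler/pixel limits
    num_patches_paddings = [
        1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
    ]
    # -> [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]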
tpu_inference/models/jax/qwen3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -125,8 +140,8 @@ class Qwen3Attention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k, v =
-
+        k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                           v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
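`quantize_kv` comes from the new `tpu_inference/layers/common/quantization.py` module (+282 lines in this release); its body is not shown in this diff. A hedged sketch of what a helper with this signature plausibly does — scale, clamp to the target dtype's range, and cast; the packaged implementation may differ:

    import jax.numpy as jnp

    def quantize_kv_sketch(dtype, k, v, k_scale, v_scale):
        # Hypothetical body, not the packaged implementation.
        info = jnp.finfo(dtype)
        k = jnp.clip(k / k_scale, info.min, info.max).astype(dtype)
        v = jnp.clip(v / v_scale, info.min, info.max).astype(dtype)
        return k, v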
tpu_inference/models/jax/utils/__init__.py

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tpu_inference/models/jax/utils/file_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import hashlib
 import os
tpu_inference/models/jax/utils/multi_modal_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Union
 
 import jax
@@ -29,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
         f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
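For orientation, a call that satisfies all three assertions in `sanity_check_mm_encoder_outputs` (argument names taken from the signature fragments above; shapes illustrative):

    import jax.numpy as jnp

    mm_embeddings = (jnp.zeros((16, 1024)), jnp.zeros((64, 1024)))  # one 2D tensor per image
    sanity_check_mm_encoder_outputs(mm_embeddings, expected_num_items=2)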
tpu_inference/models/jax/utils/qwix/__init__.py

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tpu_inference/models/jax/utils/qwix/qwix_utils.py

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import functools
 import os
 from typing import TYPE_CHECKING, Callable, List
@@ -34,17 +35,43 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
 DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
 DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
 
-
+DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
     "qwix": {
         "use_abstract_model":
         True,
         "scale_dtype":
         "bfloat16",
         "rules": [
+            # Exclude router from quantization
             {
                 "module_path": ".*.custom_module.router.*",
                 "weight_qtype": None,
             },
+            # Avoid the combine expert ops
+            {
+                "module_path": ".*combine_experts.*",
+                "weight_qtype": None,
+            },
+            # Attention layers: keep FP8 for weights and activations
+            {
+                "module_path": ".*.attn.*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+            # MoE experts: use FP4 for expert weights
+            {
+                "module_path": ".*.custom_module.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
+            # Shared experts: also FP4
+            {
+                "module_path": ".*.shared_experts.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
             {
                 "module_path": ".*",
                 "weight_qtype": "float8_e4m3fn",
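The rule list above is order-sensitive: the router and `combine_experts` exclusions must precede the catch-all `.*` FP8 rule or they would never apply. A sketch of first-match resolution, assuming Qwix walks `module_path` regexes in order (semantics assumed, not shown in this diff):

    import re

    def first_matching_rule(rules, module_path):
        # Assumption: the first rule whose module_path regex matches wins.
        for rule in rules:
            if re.fullmatch(rule["module_path"], module_path):
                return rule
        return None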
@@ -154,12 +181,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
-
-
-
-    # Handle the case where kv_cache_dtype is "auto"
-    if kv_cache_jnp_dtype is None:
-        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+    if kv_cache_dtype != "auto":
+        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+    else:
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
     kv_caches = create_kv_caches(
@@ -169,9 +193,11 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
         head_size=kv_cache_head_size,
         mesh=mesh,
         layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
-        cache_dtype=kv_cache_jnp_dtype
+        cache_dtype=kv_cache_jnp_dtype,
+        use_mla=model.vllm_config.model_config.use_mla,
+    )
 
-    dp_size =
+    dp_size = model.vllm_config.sharding_config.total_dp_size
 
     # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
     input_ids = jax.random.randint(rng,
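The rewritten branch above resolves `kv_cache_dtype` strings eagerly: anything other than "auto" goes through `utils.to_jax_dtype`, and "auto" keeps the package default. A sketch with an assumed name mapping (the real mapping lives in `tpu_inference/utils.py` and is not shown in this diff):

    import jax.numpy as jnp

    DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16  # illustrative default

    def resolve_kv_cache_dtype(kv_cache_dtype: str):
        mapping = {"fp8_e4m3": jnp.float8_e4m3fn, "bfloat16": jnp.bfloat16}  # assumed names
        if kv_cache_dtype != "auto":
            return mapping[kv_cache_dtype]
        return DEFAULT_KV_CACHE_DTYPE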
@@ -399,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
 
 
 def get_default_qwix_quantization_config(
-
-        skip_quantization: bool) -> dict | None:
+        hf_config: dict, skip_quantization: bool) -> dict | None:
     """
     Some models are pre-quantized and in those cases, we want to return a default set of
     Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -418,9 +443,42 @@ def get_default_qwix_quantization_config(
     """
     if skip_quantization:
         return None
-
+    model_type = hf_config.model_type.lower() if hasattr(
+        hf_config, "model_type") else None
+    quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+        hf_config, "quantization_config") else None
+    # TODO (jacobplatin): remove this so that we can support various quantization types + make
+    # more flexible
+    # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+    # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
+
+        # Dynamically fetch block size from HF config if available
+        # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+        # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+        hf_quant_config = hf_config.quantization_config
+        assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+        block_size = hf_quant_config["weight_block_size"]
+        if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+            assert block_size[
+                0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
+            tile_size = block_size[1]
+            assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+            logger.info(
+                f"Detected DeepSeek tile_size from config: {tile_size}")
+
+            # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+            # 256
+            for rule in config["qwix"]["rules"]:
+                if "tile_size" in rule:
+                    rule["tile_size"] = tile_size
+        else:
+            raise ValueError(
+                f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+            )
+
+        return config
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
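A worked example of the `weight_block_size` handling added above: a checkpoint declaring `[1, 512]` passes the 1D-subchannel check and overrides the default `tile_size` of 256 in the FP4 rules:

    hf_quant_config = {"weight_block_size": [1, 512]}
    block_size = hf_quant_config["weight_block_size"]
    assert block_size[0] == 1   # must be 1D subchannel
    tile_size = block_size[1]   # -> 512, written into every rule that has a tile_size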
@@ -439,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
     # Qwix quantization config accordingly
     # NOTE: if a Qwix config is provided (via the`additional_config`), we'll
     # use that instead
-
-    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-    quant_method = vllm_config.model_config.hf_config.quantization_config[
-        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                   "quantization_config") else None
+    hf_config = vllm_config.model_config.hf_config
     default_quantization_config = get_default_qwix_quantization_config(
-
-
+        hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                     False))
 
     maybe_existing_quantization_config = vllm_config.additional_config.get(
         "quantization")
@@ -503,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
     else:
-
+        # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+        # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypesgot float4_e2m1fn."
+        # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+        if dtype != "float4_e2m1fn":
+            weight = jax.random.normal(key, param_shape, dtype)
+        else:
+            weight = jax.random.normal(key, param_shape,
+                                       jnp.float8_e4m3fn).astype(dtype)
 
     def get_slice(index):
         return weight[index]
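The float4 workaround above can be reproduced standalone, assuming a JAX build whose `jax.random.normal` accepts `float8_e4m3fn` and that exposes `jnp.float4_e2m1fn` (the packaged code relies on both):

    import jax
    import jax.numpy as jnp

    key = jax.random.key(0)
    # Sample in float8, then downcast: float4_e2m1fn is rejected directly.
    w4 = jax.random.normal(key, (8, 8), jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)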
@@ -538,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     logger.info("Initializing Qwix-quantized model with random weights...")
     # TODO (jacobplatin): clean up this logic
     scale_dtype = model.weight_loader.scale_dtype
-    scale_shape_map = model.weight_loader.
+    scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
         model.weight_loader,
-        '
+        'scale_shape_map_for_random_weight_loading') else {}
     quantization_block_sizes = quantization_config["weight_block_size"]
     assert len(
         quantization_block_sizes
     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-    quantization_block_size_n, _ = quantization_block_sizes[
-        0], quantization_block_sizes[1]
 
     # Iterate through all variables and initialize them
-
+
     for path, param in nnx.iter_graph(model):
         if not isinstance(param, nnx.Variable):
             continue
@@ -559,16 +618,17 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
         param_shape = param.value.shape
-        # TODO (jacobplatin): clean this up
         if is_qwix_scale:
-
-
-
-
+            key = f"{path[2]}.{path[3]}"
+
+            if key in scale_shape_map:
+                param_shape = scale_shape_map[key]
+            else:
+                raise ValueError(
+                    f"Scale shape for {key} not found in scale_shape_map.")
         param.value = get_random_sharded_array(
             rng, mesh, param, param_shape, param_dtype,
             ".".join([str(x) for x in path]))
-        prev_param_shape = param_shape
 
         # Handles the DeepSeek case, where this needs to be called to make the cache weights
         # concrete