tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
tests/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
2
|
+
tests/test_base.py,sha256=47IflI4nIktZHlcmeqhmX9IdTKofg7OgsOiCyUTXlLw,7916
|
|
3
|
+
tests/test_envs.py,sha256=v0_R-HfWRNY8ssPqFrytHMl1irohJaTpS_rSKo2FZaY,10021
|
|
4
|
+
tests/test_tpu_info.py,sha256=OrA0Fbs9uCVqd8w7dqlGA_8KZArriyltqrCWf3hDDDU,5245
|
|
5
|
+
tests/test_utils.py,sha256=FF_41NL1VmUXDVvKr9eZg_juprqtHlUqSPR6Sisftdo,6309
|
|
6
|
+
tests/core/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
7
|
+
tests/core/test_core_tpu.py,sha256=r496rk1eOsK_F4nvm9zprl_T-RcO6eCUb7LuVReOZno,21413
|
|
8
|
+
tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
|
|
9
|
+
tests/core/test_disagg_utils.py,sha256=A5icdqkJlau2PHYAxHfHKuqrlEKXVJu2nm02XOrXjcc,2530
|
|
10
|
+
tests/core/test_dp_scheduler.py,sha256=m6ph_OH9tXz6AxNde8cIjptd1lwDVSCqIV2Ef-cNJFk,34253
|
|
11
|
+
tests/core/test_init.py,sha256=5BDDC-dmDtWEGaBPjQSiYJuMiwTBVRSDx9p7Cv8DKyI,2262
|
|
12
|
+
tests/distributed/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
13
|
+
tests/distributed/test_distributed_utils.py,sha256=YXKbSG9J72vCrU5mPiFf1ya-Yzc1BjeahdBmQVez8Wc,5031
|
|
14
|
+
tests/distributed/test_tpu_connector.py,sha256=ajKeRUi3x29hQXfLrSlo6yDczpwZsg_mGt2vKBGRZdk,20538
|
|
15
|
+
tests/e2e/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
16
|
+
tests/e2e/test_async_scheduler.py,sha256=215xGuyTEBSOe-c1l48TIjrCqhbVFZY3m5p3q5mU7jA,6905
|
|
17
|
+
tests/e2e/test_data_parallel.py,sha256=KB-_BKic_iZyn4WbPWsUdVClinzd8g7PrQ0ui5B-nwo,10725
|
|
18
|
+
tests/e2e/test_hybrid_kvcache.py,sha256=Y7a-grjvAKBbp7vbQncVEQKGM1WxcwO0qa2o0opKiEI,8076
|
|
19
|
+
tests/e2e/test_local_disagg.py,sha256=xIjYI6RGA6bZk4dluklhfYBoJGbHkrSihSkJtPgpZv4,10434
|
|
20
|
+
tests/e2e/test_model_loader.py,sha256=DYlS420KXkNzeIijAf-0UQsYH0pOAGcXRl6P99PBiAc,9366
|
|
21
|
+
tests/e2e/test_multi_modal_inference.py,sha256=hVatj8Rra6XAekp6zBxRivQUcGiV8SimPph9cZ-TJyk,3896
|
|
22
|
+
tests/e2e/test_pipeline_parallel.py,sha256=VpxY9wgQj3-i0XooHZHdmHGdMS3ilmHbxu6ZfyQDUP0,9519
|
|
23
|
+
tests/e2e/test_runai_model_streamer_loader.py,sha256=MXUxKfKV7vVM_LI7-5hBV-wCswogPENkMPsREUjFu3I,3790
|
|
24
|
+
tests/e2e/test_sampling_params.py,sha256=ibLWtJfS35HughdOBtXD2IcyWPXoZA4R4KwXz-RzgOY,10683
|
|
25
|
+
tests/e2e/test_speculative_decoding.py,sha256=tj3VSJEi7r9aHjywZanlmfY4eS5Tfr5zPe9TH3PW5EY,9911
|
|
26
|
+
tests/e2e/test_structured_decoding.py,sha256=QYh9WjGrzm7syeLrGUawA6cOkWlQqVpTn7W6qwt65NY,1863
|
|
27
|
+
tests/executors/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
28
|
+
tests/executors/test_ray_distributed_executor.py,sha256=rMazBfirGsehEUXgpIPJkw0z7xO4cnK2kzcgxjFA6Bo,8435
|
|
29
|
+
tests/experimental/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
30
|
+
tests/experimental/test_llama3_jax_stashed.py,sha256=Ruypll_7QQOdjPmF0vDL_JVk41AHnULWuJtlgscSuZQ,8126
|
|
31
|
+
tests/kernels/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
32
|
+
tests/kernels/fused_moe_v1_test.py,sha256=cnMuvS_PD29F6whBioZlikWFrDXOMHwVPdSu2x-OJR0,10978
|
|
33
|
+
tests/kernels/gmm_test.py,sha256=rWE5fnp6hAV1FaGHjHjfScfIcoHuQ5wMdRGzhjt6Qnc,6820
|
|
34
|
+
tests/kernels/mla_v1_test.py,sha256=Rmhk8jHWeXwZmouza0o_z4NqAaac5mEo9lN1ychln9I,16076
|
|
35
|
+
tests/kernels/quantized_matmul_kernel_test.py,sha256=9Q3ufAG6NY9jeEFcre_IY2JbwpQdYzzhMWbXb5yfY6Q,4796
|
|
36
|
+
tests/kernels/ragged_kv_cache_update_v2_test.py,sha256=A12DnEqB0WtAWsD6ruF49RC4zrFcFM7CrGomElxE7jU,11396
|
|
37
|
+
tests/kernels/ragged_paged_attention_kernel_v2_test.py,sha256=1SSg9EzlLIdIQQw3BMoaEWbHVp30XY2A3FQS85ot4ss,11915
|
|
38
|
+
tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py,sha256=Ugs1DBBC-ZUuhQBomqGIqUKNiawqD539Rr1BqyNaqUQ,17007
|
|
39
|
+
tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=HS60dynUGT096wCkkau4W3KJQyEQyB06P4j0LLd9-RA,15524
|
|
40
|
+
tests/kernels/collectives/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
41
|
+
tests/kernels/collectives/all_gather_matmul_kernel_test.py,sha256=ftp3CMoqiZdzD8vH0P9vNaiJx7FUICKUyxLduTqcsTk,2383
|
|
42
|
+
tests/layers/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
43
|
+
tests/layers/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
44
|
+
tests/layers/common/test_attention_interface.py,sha256=ke6h-e8CP-FhNY_ojKCYwyHgYG8aSvik1cEjCGH3VRk,5063
|
|
45
|
+
tests/layers/common/test_quantization.py,sha256=JcwDrNTm6UlBSV3s3mwwvpxOjqBpZDJwnYYoj3DnS7A,5344
|
|
46
|
+
tests/layers/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
47
|
+
tests/layers/jax/test_layers.py,sha256=L1xh_wniBtlfudya_WRmHUWOhEno0i6ikKE1XiBtaZs,5010
|
|
48
|
+
tests/layers/jax/test_qwix.py,sha256=V8MpFKJb5_evs-Z4WeZ5SxA-KAyFD6Qrex7ExywLxmE,39744
|
|
49
|
+
tests/layers/jax/test_rope.py,sha256=0biwYRSRsKMaRHknc8v8Tfrt0bmJKQGeQLPqR_D04mM,3565
|
|
50
|
+
tests/layers/jax/test_sharding.py,sha256=Hk1MWhIluOKIBx7-O9fKa1n6fF3SW7UMYsRI9AGzp_0,5914
|
|
51
|
+
tests/layers/jax/test_transformer_block.py,sha256=Wpgowc0ZJnv1GUxcK-Op6CCYWjpqgUM0p3EANk-YWzc,5742
|
|
52
|
+
tests/layers/jax/attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
53
|
+
tests/layers/jax/attention/test_common_attention.py,sha256=gXixLH2HosBp86PVwhRvwrTVVj4tl54VjrOCovwmmqM,3845
|
|
54
|
+
tests/layers/jax/attention/test_deepseek_v3_attention.py,sha256=hKxrUu4E8yfhIPj5V29p16xQxOXDvEQDzBZpyiAya3o,9292
|
|
55
|
+
tests/layers/jax/attention/test_llama4_attention.py,sha256=t1Kj0oTSFj_cVNuLl-ceZ-BY91sjx04xNRg_Epxjank,4980
|
|
56
|
+
tests/layers/jax/moe/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
57
|
+
tests/layers/jax/moe/test_deepseek_moe.py,sha256=2v7o2Svz1z6LH9tNqbL7dZtu5PSuKGiJzUccE-AMUYc,10550
|
|
58
|
+
tests/layers/jax/sample/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
59
|
+
tests/layers/jax/sample/test_rejection_sampler.py,sha256=qHvFpm-Oo6ZO0KHBN6nCB00BinbpCqxlg_QsSkAX-cI,65362
|
|
60
|
+
tests/layers/jax/sample/test_sampling.py,sha256=oCgI2YBnz5NCdwr2CWsiEFkddXnke1_S1tAIFP7D1oc,4098
|
|
61
|
+
tests/layers/jax/sample/test_sampling_metadata.py,sha256=WQCmgGkkn7sgBL9Uq7REdAkTUXq9YhbhBeuMTFtSIe8,9198
|
|
62
|
+
tests/layers/vllm/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
63
|
+
tests/layers/vllm/test_attention.py,sha256=NSbeKIi4eQj9RLiHeT-aEDvvsiHYbD3rk4uXq3_5_X8,13193
|
|
64
|
+
tests/layers/vllm/test_awq.py,sha256=khtLjyEO3wJlm3RM3eHVUtjAtB0BRtmmt57p-XfnFdA,14492
|
|
65
|
+
tests/layers/vllm/test_compressed_tensors_moe.py,sha256=Lu5M6lxFH7TetRxTNm3n6cT7su31idwZZi9MfNoP16s,7319
|
|
66
|
+
tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py,sha256=NivmHhqcSJE2NJVNYmndxldbA5T7FxMA6gBnz2EkPGo,16301
|
|
67
|
+
tests/layers/vllm/test_compressed_tensors_w8a8_int8.py,sha256=VHcCCOD1qlZst4DaBJ6vZ3PUL6n4LLFpwX9C5FKuLBY,16691
|
|
68
|
+
tests/layers/vllm/test_fp8.py,sha256=ZvFTg4Umgg6W2RwElkIZ_Rls_XZJ8sEW7yww2K3ztf4,666
|
|
69
|
+
tests/layers/vllm/test_mxfp4.py,sha256=ZOWZcBZvZV70EsrKQziBVo6hstJ9wNO3LbjQOtaKlHY,12175
|
|
70
|
+
tests/layers/vllm/test_unquantized.py,sha256=RvjImwpWaD7ZD6IhdeTwneRAtv0eTe22Qg84TMpc-ls,25095
|
|
71
|
+
tests/layers/vllm/utils.py,sha256=Qk67IqSrSovhPlWmDGFBr5vwgwtG7kcUzy69-oPgR0A,3105
|
|
72
|
+
tests/lora/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
73
|
+
tests/lora/conftest.py,sha256=OI4gPV4vNOCcfE93ccmIWQHd8-Gp9c2yGVlaSnuT4Tg,1559
|
|
74
|
+
tests/lora/test_bgmv.py,sha256=B1HCjh27379vCxZsd8nKMBZ8lr1JamuuWDgYiALyn18,1934
|
|
75
|
+
tests/lora/test_layers.py,sha256=TtIdl1SlMQ8afpkKbx6GRA9oRFAS8RjL7nqgAHxRtLM,26590
|
|
76
|
+
tests/lora/test_lora.py,sha256=Wqc6V7wQkobP-F8kHUkuMuiQYnxN775xlLUjDz6cEp0,5012
|
|
77
|
+
tests/lora/test_lora_perf.py,sha256=zcZud9Hexx6wa9qX0IvnjKyDD-i61NdIQrVO31Yx3vU,2381
|
|
78
|
+
tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
|
|
79
|
+
tests/models/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
80
|
+
tests/models/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
81
|
+
tests/models/common/test_model_loader.py,sha256=Sf-k_Kxdjkz-lS_0-ICfA4Yk2VXX33esP8PNG4B7FzA,17392
|
|
82
|
+
tests/models/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
83
|
+
tests/models/jax/test_deepseek_v3.py,sha256=9RY6ypfvPts3NOnvWu9n_T7pUjrvj_QY_saLOKpFg4c,16243
|
|
84
|
+
tests/models/jax/test_llama3.py,sha256=NYsT35yh9GzkYYcLcOo1BkBGGr14E89GtdCJJ6SFhI8,6610
|
|
85
|
+
tests/models/jax/test_llama4.py,sha256=MMQzTymnVUdWZ6XoOD8k9Q2ikmAk6tFSGB1C5DCi7pw,12605
|
|
86
|
+
tests/models/jax/test_llama_eagle3.py,sha256=DCk1ae9SLJUrqyx7uvNOmpqAAM09xb0rYNOst-Leo_M,7777
|
|
87
|
+
tests/models/jax/test_llama_guard_4.py,sha256=w-8cKwuTRFyzDh2mxvAofrt5xUprZyqRm5DRVRamGwE,9322
|
|
88
|
+
tests/models/jax/test_qwen2.py,sha256=xylG-LmHBSy76V-Yl5KiAXogpZPM2w3Mx0E61Ud5sO4,6227
|
|
89
|
+
tests/models/jax/test_qwen2_5_vl.py,sha256=PfB_gecAvXNrksxt8E56yP6d8ioZZWMoUIvh-OrbzJ4,26299
|
|
90
|
+
tests/models/jax/test_qwen3.py,sha256=NWLAZPwGIhZjW0OADk4JqU4ZPn8JGSGPwkbTQvKEc50,6021
|
|
91
|
+
tests/models/jax/test_weight_loading.py,sha256=RlmByQcjrsefybeNlS9wnL522be6CSR7YLcb7O5eZ-A,5205
|
|
92
|
+
tests/models/jax/utils/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
93
|
+
tests/models/jax/utils/test_multi_modal_utils.py,sha256=xrD8GijHGzb-n6z1W0okdjdNfREC1A9ZU7FQcbrx8zM,7867
|
|
94
|
+
tests/platforms/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
95
|
+
tests/platforms/test_tpu_platform.py,sha256=L0WUMncWzlWYWPAbtrE6Lhj-BuSjq-Ml2iKIjlmFGFE,2149
|
|
96
|
+
tests/runner/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
97
|
+
tests/runner/test_block_table.py,sha256=gFGF425mpWfOLjnQeQiG18TqFko8vpilJ3AiiiV1j8Y,14732
|
|
98
|
+
tests/runner/test_input_batch.py,sha256=7nEkB00JrhaKCKf1ep28iedYbNbuqEdaQAxYqHaXThc,8198
|
|
99
|
+
tests/runner/test_kv_cache.py,sha256=TvxmJNI8lM0ZNllZonHySA8NCQZ7prBgNODpYEI787E,7394
|
|
100
|
+
tests/runner/test_kv_cache_manager.py,sha256=dYVWQamfGwqytnumfvjRt2r3n9BRBqcSbCXGWnw1SXs,22461
|
|
101
|
+
tests/runner/test_multimodal_manager.py,sha256=8RbHHMvRuHg1Scc0b70tsr-tF2lfk8SZVx3InVgIryc,18591
|
|
102
|
+
tests/runner/test_persistent_batch_manager.py,sha256=EW6P-BtI4i59Clx-Lh84fU1GtDKF3Av2gtO-rCRYN_k,3148
|
|
103
|
+
tests/runner/test_speculative_decoding_manager.py,sha256=HgemtiBL_VhBheUgem3OpPj6yBK9vdJsL8VCABQdGXw,16093
|
|
104
|
+
tests/runner/test_structured_decoding_manager.py,sha256=pVX3z2TLR6SfBoEyRtv0BPajHbMVdcOAe4opMoxEpps,9802
|
|
105
|
+
tests/runner/test_tpu_runner.py,sha256=H1RjGGvNPfNNhglbiUs9J2QsokXaDtnmmtdoYRvA5_8,11649
|
|
106
|
+
tests/runner/test_tpu_runner_dp.py,sha256=TAEmI-JaIodgYNjjjQAAQg-q0bSbeVON5ZZE2jngfOk,50851
|
|
107
|
+
tests/runner/test_tpu_runner_mesh.py,sha256=kDyjdnd0vO4GQrcOAPLr9TEYA49-qDFE4gHt9IL6wlk,8638
|
|
108
|
+
tests/runner/test_utils.py,sha256=_R2bnKttqgg7vfPXP0Qfx38mr-4UBm2UMIbuQFAwgWk,15442
|
|
109
|
+
tests/spec_decode/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
110
|
+
tests/spec_decode/test_eagle3.py,sha256=18GbBKaMipCekyZMn24Fp-lraGEiASj2t-blohqWu7Y,12945
|
|
111
|
+
tests/worker/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
112
|
+
tests/worker/tpu_worker_test.py,sha256=lfRMW_DG2f9juR0I60uW682iDa9QvLNdtU-VLfJPUdY,17520
|
|
113
|
+
tpu_inference/__init__.py,sha256=2LJVEi6eR-RWHifo68n6D0SKYgg1NLrruW_E7Lz3oxg,2879
|
|
114
|
+
tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
|
|
115
|
+
tpu_inference/envs.py,sha256=A1Bdm5qiXhTdu-Q_yNzBpi79_nOJIDbdFF7MAMqmjxo,6662
|
|
116
|
+
tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
|
|
117
|
+
tpu_inference/tpu_info.py,sha256=lty-ngN1uUvQLlFGkWa2u5eEb5anwmcv_uyI0S95PdY,2840
|
|
118
|
+
tpu_inference/utils.py,sha256=0fQXcZJ4IiPGlNv_bLdkla5FeEEKEzyTsSDH-y47ouo,10641
|
|
119
|
+
tpu_inference/core/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
120
|
+
tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
|
|
121
|
+
tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
|
|
122
|
+
tpu_inference/core/disagg_utils.py,sha256=lv8MAVoAjtcmTaenUXVokg2q3d0tzsma86UiQlQ3omY,1492
|
|
123
|
+
tpu_inference/core/sched/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
124
|
+
tpu_inference/core/sched/dp_scheduler.py,sha256=-7d2zopJ5ZJFIJ8LbHsm_4bBBtP7qrim4XWVPDF6vrg,34960
|
|
125
|
+
tpu_inference/distributed/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
126
|
+
tpu_inference/distributed/jax_parallel_state.py,sha256=xMK0tEtblh37_LoHvp1-6qPI8AgX4HkE0ATuc7fdHKs,2798
|
|
127
|
+
tpu_inference/distributed/tpu_connector.py,sha256=3rR0y2P1MOOSM8nBfvl95ZQcVKMms3rL8zTdnxUmSms,29946
|
|
128
|
+
tpu_inference/distributed/utils.py,sha256=8pTkqI81b7Gkurn6M4zepoTUmTRaab3kfrH4ncAf5ns,3738
|
|
129
|
+
tpu_inference/executors/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
130
|
+
tpu_inference/executors/ray_distributed_executor.py,sha256=vz82tLPkQqwwUmwny1em_PrjNFZuroQPnXaEQAC5iWY,16980
|
|
131
|
+
tpu_inference/experimental/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
132
|
+
tpu_inference/experimental/llama3_jax_stashed.py,sha256=39XTuG-0C5pZe1oDznm6iCrvccZ_2CnC488YsvhxIho,11488
|
|
133
|
+
tpu_inference/kernels/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
134
|
+
tpu_inference/kernels/collectives/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
135
|
+
tpu_inference/kernels/collectives/all_gather_matmul.py,sha256=TtQWY0lNj8699JwDmjqbRrdku-3oAw5WkuuoFPS49AY,27597
|
|
136
|
+
tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py,sha256=OEPf4q08IeIFyJfzizgRs6kSD7w35NeZDRIn7CcZ344,1468
|
|
137
|
+
tpu_inference/kernels/collectives/util.py,sha256=LbLD6lOxuszbUsykF89gWQqEJUICCZsfzam3EJDPnFE,1859
|
|
138
|
+
tpu_inference/kernels/flash_attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
139
|
+
tpu_inference/kernels/flash_attention/kernel.py,sha256=n8gmAFVfchMXlyaSEj8xXJm6AadFt26edQihPRdithY,25897
|
|
140
|
+
tpu_inference/kernels/fused_moe/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
141
|
+
tpu_inference/kernels/fused_moe/v1/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
142
|
+
tpu_inference/kernels/fused_moe/v1/kernel.py,sha256=B0qWaa5vphIa3MJmeTbvpBMh9JJlRWNpmoORrz79Cvk,64990
|
|
143
|
+
tpu_inference/kernels/megablox/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
144
|
+
tpu_inference/kernels/megablox/common.py,sha256=CoJPNom6anJU9B4i05d2skytJEvNS994DYo0eEyVGuY,1639
|
|
145
|
+
tpu_inference/kernels/megablox/gmm.py,sha256=rVW70SGPshR9XvHiwzmskX4_yeD4nE8or3RfabwcCLM,24240
|
|
146
|
+
tpu_inference/kernels/mla/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
147
|
+
tpu_inference/kernels/mla/v1/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
148
|
+
tpu_inference/kernels/mla/v1/kernel.py,sha256=oovjb0x3qz08IL_KVjLLbNbcEcFXip55fqgIgfnl3RA,49758
|
|
149
|
+
tpu_inference/kernels/quantized_matmul/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
150
|
+
tpu_inference/kernels/quantized_matmul/kernel.py,sha256=-A9Kd2ApHWgPvCaUPfjM5JooLz_iCfWV1UT0taaZaAo,16264
|
|
151
|
+
tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py,sha256=3zhIm73JEE8qOty2_0v3AJlVz13k6qMB5wlXBDyC1EM,35130
|
|
152
|
+
tpu_inference/kernels/quantized_matmul/util.py,sha256=rf6nIiAj9I2cj4LDvtaZGhcLXEc94o2xgMWasnFaREM,1943
|
|
153
|
+
tpu_inference/kernels/ragged_paged_attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
154
|
+
tpu_inference/kernels/ragged_paged_attention/v2/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
155
|
+
tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=462jgsWdnaQfO9K1Y99cJ-qidYWXZMc5GdoY9enQEWY,35019
|
|
156
|
+
tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=y9-C_F28WGd282Ra_DqwTbHyUIIj2jyWY3DiX8yozHY,11080
|
|
157
|
+
tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
|
|
158
|
+
tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
159
|
+
tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=HVTQ4LJiEkWiYuUV1ey-2K2u6IULjJQ2dbX3qpo3FLA,60593
|
|
160
|
+
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=VutC0CwPfF-luuRSPv6b7QiFt2EBiCPdoTMtOrFFZtI,60391
|
|
161
|
+
tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=sG67fBe8ckXdfvO7c9gfGFhu6_8owir8ZE6IOyHhNFY,231477
|
|
162
|
+
tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=WusgnI6oDRsUoF8lp4vsaPepKO8oTJLlPSlLDpr3-7Y,25025
|
|
163
|
+
tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=VVYHEHmANvEddEKx8IPTRSXDykwzEOJa2GZKNv7nwnM,1755
|
|
164
|
+
tpu_inference/layers/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
165
|
+
tpu_inference/layers/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
166
|
+
tpu_inference/layers/common/attention_interface.py,sha256=WNEXNj1_6mDNS4KDXJRu9hkbJmKFlsp78txbqbDWhTo,13712
|
|
167
|
+
tpu_inference/layers/common/attention_metadata.py,sha256=rmipY517sefHe4owxC5USkm4lbL4zd4LZKokDYGECQo,1425
|
|
168
|
+
tpu_inference/layers/common/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
|
|
169
|
+
tpu_inference/layers/common/fused_moe_gmm.py,sha256=xzrFK1fRZXsF_a1robY1qe5I9rQ3t2kcjhN4KHmt75Q,19862
|
|
170
|
+
tpu_inference/layers/common/quant_methods.py,sha256=SCm9g7bE02XSMONmOCuT0vfHeTP6RzGQ57aTj919HgM,772
|
|
171
|
+
tpu_inference/layers/common/quantization.py,sha256=63-kb4XR3D1mCryBYhRy881W2X52m7kF_CmHeETo2R8,9216
|
|
172
|
+
tpu_inference/layers/common/sharding.py,sha256=curCejZPj8ND4rxjWEbwRozkFYlK_HlpIyTywhDHcWU,26171
|
|
173
|
+
tpu_inference/layers/common/utils.py,sha256=k1OWrJJI6E58TCNUXO7TFc5l_9XmwL3d7N2U4QE-zPs,4417
|
|
174
|
+
tpu_inference/layers/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
175
|
+
tpu_inference/layers/jax/base.py,sha256=UhT4ut_59ynUPdaZGpMPSCQkPTWXA9BxkaPy7lDhoLI,6350
|
|
176
|
+
tpu_inference/layers/jax/constants.py,sha256=YQJOeAbja1yTbPhoOWMp24OF1RCMwPybK1NIwPrrYJ0,3329
|
|
177
|
+
tpu_inference/layers/jax/layers.py,sha256=elv04eCMFj5Jt3SF0PXxyuQPTwmJDgsuvZ9oK88HTso,11208
|
|
178
|
+
tpu_inference/layers/jax/misc.py,sha256=Jdxv8SAT1yVuM_1_lGWImRSXlu2xGLnXI-TRGRNsBYw,1141
|
|
179
|
+
tpu_inference/layers/jax/pp_utils.py,sha256=gP3Xt-Pinm6E7yJ9jtsSnmmoz9GmgBN83TkSgIrz0OA,1726
|
|
180
|
+
tpu_inference/layers/jax/rope.py,sha256=FbZKJPd9T0IDaZyOJkrFl2CL1on1womCzZBiUPLU0O4,11924
|
|
181
|
+
tpu_inference/layers/jax/rope_interface.py,sha256=cPqVpKG5_SU7S7xcrMEaPBJLqi1nC4uMN-1S-dmb0mQ,8950
|
|
182
|
+
tpu_inference/layers/jax/transformer_block.py,sha256=HTI0fYPQd23UbnJSB_pL2K3un3q_i3guvJiNCUReVRs,4492
|
|
183
|
+
tpu_inference/layers/jax/attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
184
|
+
tpu_inference/layers/jax/attention/attention.py,sha256=_N5W4ox8EzC1CZYcIhsEi35X8WCIMFEBlSzVtDDcTu8,10623
|
|
185
|
+
tpu_inference/layers/jax/attention/deepseek_v3_attention.py,sha256=KP-hgck-wTzTcwDNB08DwNiqsE-6OD4tQ1jLVwWQvEw,22427
|
|
186
|
+
tpu_inference/layers/jax/attention/gpt_oss_attention.py,sha256=EM1kJpr77VHh95aSD5UnSJazB_anS_7PyaD8TixVMrY,9241
|
|
187
|
+
tpu_inference/layers/jax/attention/llama4_attention.py,sha256=QzBDoEioI9mMdI1T2LNlsr89iaGl234e-9s202YWS8M,6713
|
|
188
|
+
tpu_inference/layers/jax/moe/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
189
|
+
tpu_inference/layers/jax/moe/deepseek_v3_moe.py,sha256=5j6TJO8fAB2Yv6mVAeM2F9WLe4QDM9bf6zxtdKjHjCQ,26456
|
|
190
|
+
tpu_inference/layers/jax/moe/gpt_oss_moe.py,sha256=-uliFqHJFOTT9WJCEpGhkImOXMSoo3aePXMOmKXlgmk,6771
|
|
191
|
+
tpu_inference/layers/jax/moe/moe.py,sha256=E7L8bJucTVke89o048GAbWdtuQIL5oDz-MkW0NK4E00,10114
|
|
192
|
+
tpu_inference/layers/jax/sample/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
193
|
+
tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=VqN0mxi7Xg58w4EXS625ndC8NyA_UZMV9bjFM1mkvrY,21000
|
|
194
|
+
tpu_inference/layers/jax/sample/sampling.py,sha256=IfJBFSXuTdd0QELn8Opmh7HgdzKreIwGYUOskTFp4aI,3888
|
|
195
|
+
tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=bip7TQcw-VHyN6072zBQY-tA0-QTyJpnuYg04mw9Sv0,3136
|
|
196
|
+
tpu_inference/layers/vllm/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
197
|
+
tpu_inference/layers/vllm/attention.py,sha256=LMQbS2KAup0Q-mmN5pzV6uUs-qdGpTSH8eV6ByHde9g,7370
|
|
198
|
+
tpu_inference/layers/vllm/fused_moe.py,sha256=E4JeuCekVYsvMLJkccOrP690GL2Q_EWlLwW3ZK5NT-0,4013
|
|
199
|
+
tpu_inference/layers/vllm/linear.py,sha256=KRScVrEGys3NLpDzG0UieHb371UJR1R_ct6LR84_-iE,2428
|
|
200
|
+
tpu_inference/layers/vllm/process_weights/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
201
|
+
tpu_inference/layers/vllm/process_weights/cleanup_sharding.py,sha256=vg9PdWdY8caYKBs5G_YhkKA4MdAR213knoLv_TJlyDU,9793
|
|
202
|
+
tpu_inference/layers/vllm/process_weights/fused_moe_weights.py,sha256=vVrzVzrJ6_vUMPI_Nzqmqco2yeZb9O3CEzNII2rXWU0,14936
|
|
203
|
+
tpu_inference/layers/vllm/process_weights/linear_weights.py,sha256=3Qx-Dgdx5Khjb9B0LXmFVUz7Tc8bXf6esSfk7MWicwM,6068
|
|
204
|
+
tpu_inference/layers/vllm/quantization/__init__.py,sha256=r9oDaXh0TiDSnh2WOWEYfPDRaH3aU9uW2ANHrezZZjw,2450
|
|
205
|
+
tpu_inference/layers/vllm/quantization/awq.py,sha256=5HdRtJ1E5adCKmDIlPkIzXdgdBsSakrmRPKnQjryEwk,8595
|
|
206
|
+
tpu_inference/layers/vllm/quantization/configs.py,sha256=0q-gRrR7sxgUty1OzmIc6MrMH9dpuN_DYHISskvlpk8,4925
|
|
207
|
+
tpu_inference/layers/vllm/quantization/fp8.py,sha256=z4xXpqy7I37p6rBZjlCQRomFQzbWHOw1xWkHN3_bndw,4541
|
|
208
|
+
tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=q7EnVQlbdTy_qicmRo_mn6t5Q3fEt_cs31SUUVga8hU,8597
|
|
209
|
+
tpu_inference/layers/vllm/quantization/unquantized.py,sha256=YFZHAjmrjWnuZuwx-lG0Eka9BNqCvIq5kNbEY6vAn3Y,10795
|
|
210
|
+
tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
211
|
+
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=VuEqI7HpN39Xee-z5ohuqlu9PdlcBpFJpfe79PsJhx0,5930
|
|
212
|
+
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=8r1dT0UexEQD9-4kGiky1x7ITVpMPU90bzs-6HZQ51E,7841
|
|
213
|
+
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
214
|
+
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=W84yM33UCkCF_AZRNCoPGLqFI_EO2WHLcCfzx5TWzl4,9529
|
|
215
|
+
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=c5lqLSg-6u6Y56XYH9m1-20hlmNQ_zIB832NXDLJWJ4,6816
|
|
216
|
+
tpu_inference/lora/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
217
|
+
tpu_inference/lora/torch_lora_ops.py,sha256=YR3Hj8nLLiQ-6wXy4uFsjQxFTbJYZ4o5dh_L0mlXg-o,3261
|
|
218
|
+
tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
|
|
219
|
+
tpu_inference/models/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
220
|
+
tpu_inference/models/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
221
|
+
tpu_inference/models/common/model_loader.py,sha256=gSaY_PCRtVjx-lKsNROGmgR41E_oMba2dVxtQONADvI,21878
|
|
222
|
+
tpu_inference/models/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
223
|
+
tpu_inference/models/jax/deepseek_v3.py,sha256=mje3RgxE1NwKWVLgJnPq3ebWB1J8T6YGHT2TtxN10Dg,45031
|
|
224
|
+
tpu_inference/models/jax/gpt_oss.py,sha256=bgdsCx3UcTqEJatWBYbma5HNHH8GEaHN4aL5IsAeSmM,21592
|
|
225
|
+
tpu_inference/models/jax/jax_intermediate_tensor.py,sha256=XKpDgPkOiRtYaPrW76ILxcp2uFfSiE1JMdqHWGo0-Ss,3179
|
|
226
|
+
tpu_inference/models/jax/llama3.py,sha256=FjTGC69V_EJmvb5BIqYu3V5NS1Pvy-5Pb34kMn5YU5U,16317
|
|
227
|
+
tpu_inference/models/jax/llama4.py,sha256=Ssycb5fcGjhJYg8FfcNckVhow7bvVt0FJbbpHinzMAA,30206
|
|
228
|
+
tpu_inference/models/jax/llama_eagle3.py,sha256=_wnljvb8lLCQ0Z3Vuw0QI7F6b41x6I1WuvstZWGvCYE,13051
|
|
229
|
+
tpu_inference/models/jax/llama_guard_4.py,sha256=R4wo45s1JsVD39t8JeAItujGoi-sl43HBH95hr7qEVw,15845
|
|
230
|
+
tpu_inference/models/jax/qwen2.py,sha256=bart2yYGv0J-lNbk8Hk5jn5IF6j_Jp8YKSEjwVU_y24,14038
|
|
231
|
+
tpu_inference/models/jax/qwen2_5_vl.py,sha256=3g3tUt7c83fKOdiMzuq2VyldCyeXoCBGrVYfqyIWwGE,50370
|
|
232
|
+
tpu_inference/models/jax/qwen3.py,sha256=jVOOVrBFnxRIZ_Euo90iCga8rORpz0Kqs79uKqsFwEQ,11678
|
|
233
|
+
tpu_inference/models/jax/utils/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
234
|
+
tpu_inference/models/jax/utils/file_utils.py,sha256=8iZcGNvF1N0gNioH8fBlVYTSGYn4fC2WvmlTyeDZyZM,3415
|
|
235
|
+
tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=c2LRXdOPi3F779yg2UX-DnuFDxF1JciTcFa09iODxZs,6695
|
|
236
|
+
tpu_inference/models/jax/utils/weight_utils.py,sha256=0xyjGlDSrA09gtb4plw9yX57VPMgn3o5WNl6mXPDU70,23121
|
|
237
|
+
tpu_inference/models/jax/utils/qwix/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
238
|
+
tpu_inference/models/jax/utils/qwix/qwix_utils.py,sha256=w3wmDb1drJxOK1mVRVMORznqKbtZqFfi7H0Ib_k-iW8,29526
|
|
239
|
+
tpu_inference/models/vllm/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
240
|
+
tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=mqD0qSnRY28CJ-ZU9BLXPD4zcMui0_P2vBZsCn2KWTs,13053
|
|
241
|
+
tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=vsXQnC2aZ_mHKb-7d9UeN28lfawfApNTm5asUMgEhgo,1762
|
|
242
|
+
tpu_inference/platforms/__init__.py,sha256=BK6rwAhiqVSAUJ9m9EehSKetA6hEPe92flD9Ei076WQ,649
|
|
243
|
+
tpu_inference/platforms/tpu_platform.py,sha256=loDc6hi9DlBmcoN6CjuEt6GKYL7tXY29D086s00_M4o,9474
|
|
244
|
+
tpu_inference/runner/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
245
|
+
tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
|
|
246
|
+
tpu_inference/runner/compilation_manager.py,sha256=BFjOzJUyEJTmUZAvGCm3yeqoY7Kkw2JKc_A3CzRoN7o,42112
|
|
247
|
+
tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
|
|
248
|
+
tpu_inference/runner/kv_cache.py,sha256=xpB6VTrT3lIq5JNNPJTVEnHFgehIzgxKNIHxxXIxwKI,6046
|
|
249
|
+
tpu_inference/runner/kv_cache_manager.py,sha256=u6pXaWPzmPe34lXiy-acAdGBmp9WEQrGvksyBfGBRdM,23342
|
|
250
|
+
tpu_inference/runner/lora_utils.py,sha256=LgnrePvkBFyMvQqSp9VfrIbWPBwpWG4_iUaj3lX0Os8,4448
|
|
251
|
+
tpu_inference/runner/multimodal_manager.py,sha256=sNzj_U4XTRQtuslKljxbcS6NRNlFB_bN6l0qpnqrlfM,10315
|
|
252
|
+
tpu_inference/runner/persistent_batch_manager.py,sha256=aCeTyqCgBnQy_6hXjiNLtF81ekG0-YwlQiWeJhx-pdM,13838
|
|
253
|
+
tpu_inference/runner/speculative_decoding_manager.py,sha256=-eSxTIGXbRWRZjHJfikb7kfqbtr_cj7Pca9zInWSn1w,10790
|
|
254
|
+
tpu_inference/runner/structured_decoding_manager.py,sha256=sj1fPrit0qdhcQtDbue5kpxos7zL16_dZQ5YSXTDbzg,4148
|
|
255
|
+
tpu_inference/runner/tpu_runner.py,sha256=cgIyZiI3UjpvPWhNRL-mCSnssbbDNt00g5idAzwgWR0,80736
|
|
256
|
+
tpu_inference/runner/utils.py,sha256=lKqL5nxGTk7ufzJRNdp4udn2bPu3jIX52W7akXgSrHc,17133
|
|
257
|
+
tpu_inference/spec_decode/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
258
|
+
tpu_inference/spec_decode/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
259
|
+
tpu_inference/spec_decode/jax/eagle3.py,sha256=5WtEbkgzXpmFz374ibQD5IIcRro4d0SNeCYgBv2nM1c,19678
|
|
260
|
+
tpu_inference/worker/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
|
|
261
|
+
tpu_inference/worker/tpu_worker.py,sha256=ntwCibPyiw-z8aMUdtu8usqU_q2b0u7diWNOmpjG_6o,21651
|
|
262
|
+
tpu_inference-0.13.2.dev20251230.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
263
|
+
tpu_inference-0.13.2.dev20251230.dist-info/METADATA,sha256=08-onD7oUGsgmWyILrp51XmacHdKXu1X824ws4eoh88,5767
|
|
264
|
+
tpu_inference-0.13.2.dev20251230.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
265
|
+
tpu_inference-0.13.2.dev20251230.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
|
|
266
|
+
tpu_inference-0.13.2.dev20251230.dist-info/RECORD,,
|
|
@@ -1,208 +0,0 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
2
|
-
|
|
3
|
-
import jax
|
|
4
|
-
import jax.numpy as jnp
|
|
5
|
-
import torch
|
|
6
|
-
from jax.experimental.shard_map import shard_map
|
|
7
|
-
from jax.sharding import Mesh, NamedSharding
|
|
8
|
-
from jax.sharding import PartitionSpec as P
|
|
9
|
-
from torchax.interop import torch_view
|
|
10
|
-
from torchax.ops.mappings import t2j
|
|
11
|
-
|
|
12
|
-
from tpu_inference import envs
|
|
13
|
-
from tpu_inference.kernels.quantized_matmul.kernel import (
|
|
14
|
-
quantized_matmul_kernel, xla_quantized_matmul)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
                             mesh: Mesh, weight_sharding: P) -> jax.Array:
    """
    Run a quantized matmul, either via XLA or via the sharded kernel.

    Args:
        x: Activation.
        w_q: Weight quantized array. [n_output_features, n_input_features]
        w_s: Weight quantization scale. [n_output_features]
        mesh: Mesh to shard on.
        weight_sharding: PartitionSpec for the weight tensor. It is unpacked
            as (output-feature axis, input-feature axis).

    Returns:
        Output of the quantized matmul.
    """
    # NOTE (jacobplatin/kyuyeunk) the kernel has shown numeric issues (NaNs),
    # so it only runs when envs.ENABLE_QUANTIZED_MATMUL_KERNEL is set.
    if not envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
        return xla_quantized_matmul(x, w_q, w_s)

    out_axis, in_axis = weight_sharding
    activation_spec = P(None, in_axis)
    scale_spec = P(out_axis, )
    result_spec = P(None, out_axis)

    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, activation_spec))

    def _per_shard_matmul(x_shard, w_q_shard, w_s_shard):
        partial = quantized_matmul_kernel(x_shard,
                                          w_q_shard,
                                          w_s_shard,
                                          x_q_dtype=w_q_shard.dtype)
        if not in_axis:
            return partial
        # The input-feature axis is sharded, so the per-shard results are
        # summed across that mesh axis.
        return jax.lax.psum(partial, axis_name=in_axis)

    sharded_matmul = shard_map(_per_shard_matmul,
                               mesh=mesh,
                               in_specs=(activation_spec, weight_sharding,
                                         scale_spec),
                               out_specs=result_spec,
                               check_rep=False)
    return sharded_matmul(x, w_q, w_s)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
                                             split_sizes: list[int],
                                             n_shards: int, dim: int):
    """
    Reorder a replicated concatenated tensor such that when sharded on multiple chips, each shard is a concatenation of the shards of the individual tensors.
    For example, let the concatenated_tensor be:
        AAAAAAAAAAAABBBBBBBBCCCC
            12 As     8 Bs  4 Cs
    and let the split_sizes = [12, 8, 4] and n_shards = 4.
    The output is:
        AAABBCAAABBCAAABBCAAABBC
    In other words, it reorders the input tensor into 4 segments, with each segment corresponding to a shard and being AAABBC.
    Args:
        concatenated_tensor: the tensor, concatenated on the dimension specified by `dim`.
        split_sizes: each individual tensor's size on the dimension specified by `dim`.
        n_shards: num of shards.
        dim: the dimension on which the concatenated_tensor is concatenated.
            Negative values count from the last dimension.
    Returns:
        The reordered tensor, with the same shape as `concatenated_tensor`.
    """
    old_shape = concatenated_tensor.shape
    # Normalize a negative dim first: the tuple slicing below
    # (old_shape[:axis] / old_shape[axis + 1:]) is only correct for a
    # non-negative axis index.
    axis = dim + len(old_shape) if dim < 0 else dim
    # New shape ensures each split_tensor[i] maps to a tensor in the ith shard.
    new_shape = old_shape[:axis] + (n_shards, -1) + old_shape[axis + 1:]

    # Split the concatenated tensor into individual tensors.
    split_tensors = []
    start_offset = 0
    for split_size in split_sizes:
        split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
                                            start_offset,
                                            start_offset + split_size,
                                            axis=axis)
        split_tensors.append(split_tensor.reshape(new_shape))
        start_offset += split_size
    # Keep the newly inserted shard dim (at index `axis`) intact and
    # concatenate along the dim right after it, so every shard row holds its
    # slice of each individual tensor back to back.
    reordered_tensor = jnp.concatenate(split_tensors, axis=axis + 1)
    return reordered_tensor.reshape(old_shape)
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
                                           split_sizes: list[int],
                                           n_shards: int):
    """
    Slice the input tensor which is sharded on multiple chips (on the last dim) into individual tensors with the same sharding.
    For example, let the sharded_tensor be:
        AAABBC | AAABBC | AAABBC | AAABBC
        Shard0   Shard1   Shard2   Shard3
    and let the split_sizes = [12, 8, 4] and n_shards = 4.
    The output is a list of 3 tensors:
        AAA  | AAA  | AAA  | AAA
        BB   | BB   | BB   | BB
        C    | C    | C    | C
        Shard0 Shard1 Shard2 Shard3
    In other words, each individual tensor is a slice of the input tensor with the same sharding.
    Args:
        sharded_tensor: the input tensor, sharded on the last dim.
        split_sizes: each individual tensor's total size (across all shards) on the last dim.
        n_shards: num of shards.
    Returns:
        A list with one tensor per entry of `split_sizes`.
    Raises:
        ValueError: if a split size is not divisible by `n_shards`.
    """
    # Validate with a real exception instead of `assert`: asserts are removed
    # under `python -O`, after which the floor division below would silently
    # produce wrong slice boundaries.
    for split_size in split_sizes:
        if split_size % n_shards != 0:
            raise ValueError(
                f"split size {split_size} is not divisible by "
                f"n_shards={n_shards}")

    new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
    # New shape ensures each sharded_tensor[..., i, :] maps to the ith shard.
    sharded_tensor = sharded_tensor.reshape(new_shape)

    split_tensors = []
    start_offset = 0
    for split_size in split_sizes:
        sz = split_size // n_shards  # size of this split tensor per shard
        end_offset = start_offset + sz
        # Because we are slicing over last dim, sharding dim remains intact.
        # Therefore, splitting happens locally.
        split_tensor = sharded_tensor[..., start_offset:end_offset]
        # Merge the (shard, per-shard) dims back into a single last dim.
        split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
        start_offset = end_offset

    return split_tensors
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
def torch_to_jax_param(
    tensor: torch.Tensor,
    sharding: NamedSharding,
    output_sizes: Optional[list[int]],
    n_shards: int,
    fused: bool,
    dim: int = 0,
    jax_dtype: Optional[jnp.dtype] = None,
) -> Union[torch.nn.Parameter, torch.nn.ParameterList]:
    """Convert a torch tensor into sharded, JAX-backed torch parameter(s).

    Args:
        tensor: the torch tensor to convert.
        sharding: sharding applied to the resulting JAX array(s).
        output_sizes: sizes, along `dim`, of the individual tensors that are
            concatenated in `tensor`. None treats `tensor` as a single
            tensor spanning the whole of `dim`.
        n_shards: num of shards; only used when `fused` is True.
        fused: if True, keep one tensor (reordered so that each shard holds
            its slice of every individual tensor); if False, split into one
            parameter per entry of `output_sizes`.
        dim: the dimension on which `tensor` is concatenated.
        jax_dtype: optional dtype to cast the JAX array to.

    Returns:
        A single Parameter when `fused` is True, otherwise a ParameterList
        with one Parameter per entry of `output_sizes`.
    """
    if output_sizes is None:
        # The default must span the dimension that is actually split;
        # using shape[0] was only correct for dim == 0.
        output_sizes = [tensor.shape[dim]]

    tensor = t2j(tensor, use_dlpack=False)
    # Compare against None rather than relying on truthiness: some dtype
    # objects (e.g. non-structured numpy.dtype instances, whose len() is 0)
    # evaluate as falsy, which would silently skip the cast.
    if jax_dtype is not None:
        tensor = tensor.astype(jax_dtype)

    if fused:
        tensor = reorder_concatenated_tensor_for_sharding(
            tensor, output_sizes, n_shards, dim)
        tensor = jax.device_put(tensor, sharding)
        return torch.nn.Parameter(torch_view(tensor), requires_grad=False)

    tensors = []
    start_offset = 0
    for size in output_sizes:
        end_offset = start_offset + size
        tensor_split = jax.lax.slice_in_dim(tensor,
                                            start_offset,
                                            end_offset,
                                            axis=dim)
        tensor_split = jax.device_put(tensor_split, sharding)
        tensors.append(
            torch.nn.Parameter(torch_view(tensor_split),
                               requires_grad=False))
        start_offset = end_offset
    return torch.nn.ParameterList(tensors)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
# Per-configuration matmul fusion decisions.
# Key: (model_name, batch_size, tp_size, layer_name), the same tuple that
# get_model_matmul_fusion_assignment builds for its lookup.
# Value: presumably whether that layer's matmuls should be fused (TODO
# confirm against the callers of get_model_matmul_fusion_assignment).
MODEL_MATMUL_FUSION_TRUTH_TABLE = {
    # Qwen/Qwen2.5-7B-Instruct
    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"): True,
    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"): False,
    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"): False,
    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"): False,
    # meta-llama/Llama-3.1-8B-Instruct
    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"): False,
    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"): False,
    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"): False,
    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"): False,
    # RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"): False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"): False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"): False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"): False,
}
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
                                       tp_size: int, layer_name: str):
    """Look up the fusion decision for a model/batch/TP/layer combination.

    Returns the MODEL_MATMUL_FUSION_TRUTH_TABLE entry for the given
    (model_name, batch_size, tp_size, layer_name) tuple, or True when the
    combination is not listed in the table.
    """
    return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(
        (model_name, batch_size, tp_size, layer_name), True)
|
|
File without changes
|
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
-
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
# MXFP4 constants
# Number of FP4 values covered by one scale (one block).
MXFP4_BLOCK_SIZE: int = 32
# Exponent-only e8m0 scale bias used by MXFP4 scales
MXFP4_SCALE_BIAS: int = 127
# Name used in config.json quantization_config["quant_method"]
MXFP4_QUANT_METHOD: str = "mxfp4"

# Precompute a small LUT once; move to device on demand (cheap 16-element copy).
# Entry i is the float value of the 4-bit code i.
FP4_LUT = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,  # 0b0000-0b0111
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,  # 0b1000-0b1111
    ],
    dtype=torch.float32,
)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def unpack_mxfp4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack uint8 (..., 16) -> fp4 values (..., 32) using low->high nibble order.

    Each byte yields two values: the low nibble first, then the high nibble.

    Args:
        packed: uint8 tensor; each byte packs two 4-bit FP4 codes.

    Returns:
        float32 values corresponding to FP4 codebook entries; the last
        dimension is twice as long as that of `packed`.

    Raises:
        ValueError: if `packed` is not a uint8 tensor.
    """
    # Raise instead of asserting so the check survives `python -O` and matches
    # the ValueError raised by the other MXFP4 helpers in this module.
    if packed.dtype != torch.uint8:
        raise ValueError(f"Expected uint8 input, got {packed.dtype}")
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave as (low, high) pairs, then merge the pair dim into the last
    # dim.
    idx = torch.stack([low, high], dim=-1).flatten(-2)
    lut = FP4_LUT.to(packed.device)
    return lut[idx.long()]
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def e8m0_to_fp32(u8: torch.Tensor) -> torch.Tensor:
    """Convert e8m0 uint8 exponents to power-of-two scales using MXFP4_SCALE_BIAS.

    Computes 1.0 * 2**(u8 - bias) via ldexp and returns a float32 tensor
    with the same shape as `u8`.
    """
    unit = torch.ones_like(u8, dtype=torch.float32)
    # Widen to int32 before removing the bias so the subtraction cannot wrap
    # around in uint8; the difference stays int32.
    unbiased = u8.to(torch.int32) - int(MXFP4_SCALE_BIAS)
    return torch.ldexp(unit, unbiased)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def dequant_mxfp4_to_bf16(blocks_u8: torch.Tensor,
                          scales_u8: torch.Tensor) -> torch.Tensor:
    """Dequantize MXFP4 blocks/scales into bfloat16 values.

    Args:
        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.

    Returns:
        torch.bfloat16 tensor with last logical dimension K = Kb * 32.

    Raises:
        ValueError: if either input is not a uint8 tensor.
    """
    both_uint8 = (blocks_u8.dtype == torch.uint8
                  and scales_u8.dtype == torch.uint8)
    if not both_uint8:
        raise ValueError(
            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
        )
    # Decode the packed codes to float32 values, one row per block.
    codes = unpack_mxfp4(blocks_u8)  # (..., Kb, 32)
    # One power-of-two scale per block, broadcast over the block's values.
    block_scales = e8m0_to_fp32(scales_u8).unsqueeze(-1)  # (..., Kb, 1)
    scaled = codes * block_scales
    # Merge the (block, value) dims into a single last dim.
    flat_len = codes.shape[-2] * MXFP4_BLOCK_SIZE
    dequantized = scaled.reshape(*codes.shape[:-2], flat_len)
    return dequantized.to(torch.bfloat16)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def unpack_mxfp4_to_fp32(
        blocks_u8: torch.Tensor,
        scales_u8: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode MXFP4 packed blocks and e8m0 scales to float32 codes and scales.

    Unlike dequant_mxfp4_to_bf16, the scales are returned separately and are
    not multiplied into the codes.

    Args:
        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte packs two FP4 codes.
        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per block.

    Returns:
        (codes_fp32, scales_fp32), where
        - codes_fp32 has shape [..., Kb*32] and dtype float32
        - scales_fp32 has shape [..., Kb] and dtype float32

    Raises:
        ValueError: if either input is not a uint8 tensor.
    """
    both_uint8 = (blocks_u8.dtype == torch.uint8
                  and scales_u8.dtype == torch.uint8)
    if not both_uint8:
        raise ValueError(
            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
        )
    unpacked = unpack_mxfp4(blocks_u8)  # (..., Kb, 32) float32
    # Merge the (block, value) dims into a single last dim.
    flat_len = unpacked.shape[-2] * MXFP4_BLOCK_SIZE
    codes_fp32 = unpacked.reshape(*unpacked.shape[:-2], flat_len)
    scales_fp32 = e8m0_to_fp32(scales_u8)  # (..., Kb) float32
    return codes_fp32, scales_fp32
|