tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the package's registry page for more details.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,266 @@
1
+ tests/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
2
+ tests/test_base.py,sha256=47IflI4nIktZHlcmeqhmX9IdTKofg7OgsOiCyUTXlLw,7916
3
+ tests/test_envs.py,sha256=v0_R-HfWRNY8ssPqFrytHMl1irohJaTpS_rSKo2FZaY,10021
4
+ tests/test_tpu_info.py,sha256=OrA0Fbs9uCVqd8w7dqlGA_8KZArriyltqrCWf3hDDDU,5245
5
+ tests/test_utils.py,sha256=FF_41NL1VmUXDVvKr9eZg_juprqtHlUqSPR6Sisftdo,6309
6
+ tests/core/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
7
+ tests/core/test_core_tpu.py,sha256=r496rk1eOsK_F4nvm9zprl_T-RcO6eCUb7LuVReOZno,21413
8
+ tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
9
+ tests/core/test_disagg_utils.py,sha256=A5icdqkJlau2PHYAxHfHKuqrlEKXVJu2nm02XOrXjcc,2530
10
+ tests/core/test_dp_scheduler.py,sha256=m6ph_OH9tXz6AxNde8cIjptd1lwDVSCqIV2Ef-cNJFk,34253
11
+ tests/core/test_init.py,sha256=5BDDC-dmDtWEGaBPjQSiYJuMiwTBVRSDx9p7Cv8DKyI,2262
12
+ tests/distributed/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
13
+ tests/distributed/test_distributed_utils.py,sha256=YXKbSG9J72vCrU5mPiFf1ya-Yzc1BjeahdBmQVez8Wc,5031
14
+ tests/distributed/test_tpu_connector.py,sha256=ajKeRUi3x29hQXfLrSlo6yDczpwZsg_mGt2vKBGRZdk,20538
15
+ tests/e2e/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
16
+ tests/e2e/test_async_scheduler.py,sha256=215xGuyTEBSOe-c1l48TIjrCqhbVFZY3m5p3q5mU7jA,6905
17
+ tests/e2e/test_data_parallel.py,sha256=KB-_BKic_iZyn4WbPWsUdVClinzd8g7PrQ0ui5B-nwo,10725
18
+ tests/e2e/test_hybrid_kvcache.py,sha256=Y7a-grjvAKBbp7vbQncVEQKGM1WxcwO0qa2o0opKiEI,8076
19
+ tests/e2e/test_local_disagg.py,sha256=xIjYI6RGA6bZk4dluklhfYBoJGbHkrSihSkJtPgpZv4,10434
20
+ tests/e2e/test_model_loader.py,sha256=DYlS420KXkNzeIijAf-0UQsYH0pOAGcXRl6P99PBiAc,9366
21
+ tests/e2e/test_multi_modal_inference.py,sha256=hVatj8Rra6XAekp6zBxRivQUcGiV8SimPph9cZ-TJyk,3896
22
+ tests/e2e/test_pipeline_parallel.py,sha256=VpxY9wgQj3-i0XooHZHdmHGdMS3ilmHbxu6ZfyQDUP0,9519
23
+ tests/e2e/test_runai_model_streamer_loader.py,sha256=MXUxKfKV7vVM_LI7-5hBV-wCswogPENkMPsREUjFu3I,3790
24
+ tests/e2e/test_sampling_params.py,sha256=ibLWtJfS35HughdOBtXD2IcyWPXoZA4R4KwXz-RzgOY,10683
25
+ tests/e2e/test_speculative_decoding.py,sha256=tj3VSJEi7r9aHjywZanlmfY4eS5Tfr5zPe9TH3PW5EY,9911
26
+ tests/e2e/test_structured_decoding.py,sha256=QYh9WjGrzm7syeLrGUawA6cOkWlQqVpTn7W6qwt65NY,1863
27
+ tests/executors/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
28
+ tests/executors/test_ray_distributed_executor.py,sha256=rMazBfirGsehEUXgpIPJkw0z7xO4cnK2kzcgxjFA6Bo,8435
29
+ tests/experimental/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
30
+ tests/experimental/test_llama3_jax_stashed.py,sha256=Ruypll_7QQOdjPmF0vDL_JVk41AHnULWuJtlgscSuZQ,8126
31
+ tests/kernels/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
32
+ tests/kernels/fused_moe_v1_test.py,sha256=cnMuvS_PD29F6whBioZlikWFrDXOMHwVPdSu2x-OJR0,10978
33
+ tests/kernels/gmm_test.py,sha256=rWE5fnp6hAV1FaGHjHjfScfIcoHuQ5wMdRGzhjt6Qnc,6820
34
+ tests/kernels/mla_v1_test.py,sha256=Rmhk8jHWeXwZmouza0o_z4NqAaac5mEo9lN1ychln9I,16076
35
+ tests/kernels/quantized_matmul_kernel_test.py,sha256=9Q3ufAG6NY9jeEFcre_IY2JbwpQdYzzhMWbXb5yfY6Q,4796
36
+ tests/kernels/ragged_kv_cache_update_v2_test.py,sha256=A12DnEqB0WtAWsD6ruF49RC4zrFcFM7CrGomElxE7jU,11396
37
+ tests/kernels/ragged_paged_attention_kernel_v2_test.py,sha256=1SSg9EzlLIdIQQw3BMoaEWbHVp30XY2A3FQS85ot4ss,11915
38
+ tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py,sha256=Ugs1DBBC-ZUuhQBomqGIqUKNiawqD539Rr1BqyNaqUQ,17007
39
+ tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=HS60dynUGT096wCkkau4W3KJQyEQyB06P4j0LLd9-RA,15524
40
+ tests/kernels/collectives/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
41
+ tests/kernels/collectives/all_gather_matmul_kernel_test.py,sha256=ftp3CMoqiZdzD8vH0P9vNaiJx7FUICKUyxLduTqcsTk,2383
42
+ tests/layers/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
43
+ tests/layers/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
44
+ tests/layers/common/test_attention_interface.py,sha256=ke6h-e8CP-FhNY_ojKCYwyHgYG8aSvik1cEjCGH3VRk,5063
45
+ tests/layers/common/test_quantization.py,sha256=JcwDrNTm6UlBSV3s3mwwvpxOjqBpZDJwnYYoj3DnS7A,5344
46
+ tests/layers/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
47
+ tests/layers/jax/test_layers.py,sha256=L1xh_wniBtlfudya_WRmHUWOhEno0i6ikKE1XiBtaZs,5010
48
+ tests/layers/jax/test_qwix.py,sha256=V8MpFKJb5_evs-Z4WeZ5SxA-KAyFD6Qrex7ExywLxmE,39744
49
+ tests/layers/jax/test_rope.py,sha256=0biwYRSRsKMaRHknc8v8Tfrt0bmJKQGeQLPqR_D04mM,3565
50
+ tests/layers/jax/test_sharding.py,sha256=Hk1MWhIluOKIBx7-O9fKa1n6fF3SW7UMYsRI9AGzp_0,5914
51
+ tests/layers/jax/test_transformer_block.py,sha256=Wpgowc0ZJnv1GUxcK-Op6CCYWjpqgUM0p3EANk-YWzc,5742
52
+ tests/layers/jax/attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
53
+ tests/layers/jax/attention/test_common_attention.py,sha256=gXixLH2HosBp86PVwhRvwrTVVj4tl54VjrOCovwmmqM,3845
54
+ tests/layers/jax/attention/test_deepseek_v3_attention.py,sha256=hKxrUu4E8yfhIPj5V29p16xQxOXDvEQDzBZpyiAya3o,9292
55
+ tests/layers/jax/attention/test_llama4_attention.py,sha256=t1Kj0oTSFj_cVNuLl-ceZ-BY91sjx04xNRg_Epxjank,4980
56
+ tests/layers/jax/moe/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
57
+ tests/layers/jax/moe/test_deepseek_moe.py,sha256=2v7o2Svz1z6LH9tNqbL7dZtu5PSuKGiJzUccE-AMUYc,10550
58
+ tests/layers/jax/sample/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
59
+ tests/layers/jax/sample/test_rejection_sampler.py,sha256=qHvFpm-Oo6ZO0KHBN6nCB00BinbpCqxlg_QsSkAX-cI,65362
60
+ tests/layers/jax/sample/test_sampling.py,sha256=oCgI2YBnz5NCdwr2CWsiEFkddXnke1_S1tAIFP7D1oc,4098
61
+ tests/layers/jax/sample/test_sampling_metadata.py,sha256=WQCmgGkkn7sgBL9Uq7REdAkTUXq9YhbhBeuMTFtSIe8,9198
62
+ tests/layers/vllm/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
63
+ tests/layers/vllm/test_attention.py,sha256=NSbeKIi4eQj9RLiHeT-aEDvvsiHYbD3rk4uXq3_5_X8,13193
64
+ tests/layers/vllm/test_awq.py,sha256=khtLjyEO3wJlm3RM3eHVUtjAtB0BRtmmt57p-XfnFdA,14492
65
+ tests/layers/vllm/test_compressed_tensors_moe.py,sha256=Lu5M6lxFH7TetRxTNm3n6cT7su31idwZZi9MfNoP16s,7319
66
+ tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py,sha256=NivmHhqcSJE2NJVNYmndxldbA5T7FxMA6gBnz2EkPGo,16301
67
+ tests/layers/vllm/test_compressed_tensors_w8a8_int8.py,sha256=VHcCCOD1qlZst4DaBJ6vZ3PUL6n4LLFpwX9C5FKuLBY,16691
68
+ tests/layers/vllm/test_fp8.py,sha256=ZvFTg4Umgg6W2RwElkIZ_Rls_XZJ8sEW7yww2K3ztf4,666
69
+ tests/layers/vllm/test_mxfp4.py,sha256=ZOWZcBZvZV70EsrKQziBVo6hstJ9wNO3LbjQOtaKlHY,12175
70
+ tests/layers/vllm/test_unquantized.py,sha256=RvjImwpWaD7ZD6IhdeTwneRAtv0eTe22Qg84TMpc-ls,25095
71
+ tests/layers/vllm/utils.py,sha256=Qk67IqSrSovhPlWmDGFBr5vwgwtG7kcUzy69-oPgR0A,3105
72
+ tests/lora/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
73
+ tests/lora/conftest.py,sha256=OI4gPV4vNOCcfE93ccmIWQHd8-Gp9c2yGVlaSnuT4Tg,1559
74
+ tests/lora/test_bgmv.py,sha256=B1HCjh27379vCxZsd8nKMBZ8lr1JamuuWDgYiALyn18,1934
75
+ tests/lora/test_layers.py,sha256=TtIdl1SlMQ8afpkKbx6GRA9oRFAS8RjL7nqgAHxRtLM,26590
76
+ tests/lora/test_lora.py,sha256=Wqc6V7wQkobP-F8kHUkuMuiQYnxN775xlLUjDz6cEp0,5012
77
+ tests/lora/test_lora_perf.py,sha256=zcZud9Hexx6wa9qX0IvnjKyDD-i61NdIQrVO31Yx3vU,2381
78
+ tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
79
+ tests/models/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
80
+ tests/models/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
81
+ tests/models/common/test_model_loader.py,sha256=Sf-k_Kxdjkz-lS_0-ICfA4Yk2VXX33esP8PNG4B7FzA,17392
82
+ tests/models/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
83
+ tests/models/jax/test_deepseek_v3.py,sha256=9RY6ypfvPts3NOnvWu9n_T7pUjrvj_QY_saLOKpFg4c,16243
84
+ tests/models/jax/test_llama3.py,sha256=NYsT35yh9GzkYYcLcOo1BkBGGr14E89GtdCJJ6SFhI8,6610
85
+ tests/models/jax/test_llama4.py,sha256=MMQzTymnVUdWZ6XoOD8k9Q2ikmAk6tFSGB1C5DCi7pw,12605
86
+ tests/models/jax/test_llama_eagle3.py,sha256=DCk1ae9SLJUrqyx7uvNOmpqAAM09xb0rYNOst-Leo_M,7777
87
+ tests/models/jax/test_llama_guard_4.py,sha256=w-8cKwuTRFyzDh2mxvAofrt5xUprZyqRm5DRVRamGwE,9322
88
+ tests/models/jax/test_qwen2.py,sha256=xylG-LmHBSy76V-Yl5KiAXogpZPM2w3Mx0E61Ud5sO4,6227
89
+ tests/models/jax/test_qwen2_5_vl.py,sha256=PfB_gecAvXNrksxt8E56yP6d8ioZZWMoUIvh-OrbzJ4,26299
90
+ tests/models/jax/test_qwen3.py,sha256=NWLAZPwGIhZjW0OADk4JqU4ZPn8JGSGPwkbTQvKEc50,6021
91
+ tests/models/jax/test_weight_loading.py,sha256=RlmByQcjrsefybeNlS9wnL522be6CSR7YLcb7O5eZ-A,5205
92
+ tests/models/jax/utils/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
93
+ tests/models/jax/utils/test_multi_modal_utils.py,sha256=xrD8GijHGzb-n6z1W0okdjdNfREC1A9ZU7FQcbrx8zM,7867
94
+ tests/platforms/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
95
+ tests/platforms/test_tpu_platform.py,sha256=L0WUMncWzlWYWPAbtrE6Lhj-BuSjq-Ml2iKIjlmFGFE,2149
96
+ tests/runner/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
97
+ tests/runner/test_block_table.py,sha256=gFGF425mpWfOLjnQeQiG18TqFko8vpilJ3AiiiV1j8Y,14732
98
+ tests/runner/test_input_batch.py,sha256=7nEkB00JrhaKCKf1ep28iedYbNbuqEdaQAxYqHaXThc,8198
99
+ tests/runner/test_kv_cache.py,sha256=TvxmJNI8lM0ZNllZonHySA8NCQZ7prBgNODpYEI787E,7394
100
+ tests/runner/test_kv_cache_manager.py,sha256=dYVWQamfGwqytnumfvjRt2r3n9BRBqcSbCXGWnw1SXs,22461
101
+ tests/runner/test_multimodal_manager.py,sha256=8RbHHMvRuHg1Scc0b70tsr-tF2lfk8SZVx3InVgIryc,18591
102
+ tests/runner/test_persistent_batch_manager.py,sha256=EW6P-BtI4i59Clx-Lh84fU1GtDKF3Av2gtO-rCRYN_k,3148
103
+ tests/runner/test_speculative_decoding_manager.py,sha256=HgemtiBL_VhBheUgem3OpPj6yBK9vdJsL8VCABQdGXw,16093
104
+ tests/runner/test_structured_decoding_manager.py,sha256=pVX3z2TLR6SfBoEyRtv0BPajHbMVdcOAe4opMoxEpps,9802
105
+ tests/runner/test_tpu_runner.py,sha256=H1RjGGvNPfNNhglbiUs9J2QsokXaDtnmmtdoYRvA5_8,11649
106
+ tests/runner/test_tpu_runner_dp.py,sha256=TAEmI-JaIodgYNjjjQAAQg-q0bSbeVON5ZZE2jngfOk,50851
107
+ tests/runner/test_tpu_runner_mesh.py,sha256=kDyjdnd0vO4GQrcOAPLr9TEYA49-qDFE4gHt9IL6wlk,8638
108
+ tests/runner/test_utils.py,sha256=_R2bnKttqgg7vfPXP0Qfx38mr-4UBm2UMIbuQFAwgWk,15442
109
+ tests/spec_decode/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
110
+ tests/spec_decode/test_eagle3.py,sha256=18GbBKaMipCekyZMn24Fp-lraGEiASj2t-blohqWu7Y,12945
111
+ tests/worker/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
112
+ tests/worker/tpu_worker_test.py,sha256=lfRMW_DG2f9juR0I60uW682iDa9QvLNdtU-VLfJPUdY,17520
113
+ tpu_inference/__init__.py,sha256=2LJVEi6eR-RWHifo68n6D0SKYgg1NLrruW_E7Lz3oxg,2879
114
+ tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
115
+ tpu_inference/envs.py,sha256=A1Bdm5qiXhTdu-Q_yNzBpi79_nOJIDbdFF7MAMqmjxo,6662
116
+ tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
117
+ tpu_inference/tpu_info.py,sha256=lty-ngN1uUvQLlFGkWa2u5eEb5anwmcv_uyI0S95PdY,2840
118
+ tpu_inference/utils.py,sha256=0fQXcZJ4IiPGlNv_bLdkla5FeEEKEzyTsSDH-y47ouo,10641
119
+ tpu_inference/core/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
120
+ tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
121
+ tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
122
+ tpu_inference/core/disagg_utils.py,sha256=lv8MAVoAjtcmTaenUXVokg2q3d0tzsma86UiQlQ3omY,1492
123
+ tpu_inference/core/sched/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
124
+ tpu_inference/core/sched/dp_scheduler.py,sha256=-7d2zopJ5ZJFIJ8LbHsm_4bBBtP7qrim4XWVPDF6vrg,34960
125
+ tpu_inference/distributed/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
126
+ tpu_inference/distributed/jax_parallel_state.py,sha256=xMK0tEtblh37_LoHvp1-6qPI8AgX4HkE0ATuc7fdHKs,2798
127
+ tpu_inference/distributed/tpu_connector.py,sha256=3rR0y2P1MOOSM8nBfvl95ZQcVKMms3rL8zTdnxUmSms,29946
128
+ tpu_inference/distributed/utils.py,sha256=8pTkqI81b7Gkurn6M4zepoTUmTRaab3kfrH4ncAf5ns,3738
129
+ tpu_inference/executors/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
130
+ tpu_inference/executors/ray_distributed_executor.py,sha256=vz82tLPkQqwwUmwny1em_PrjNFZuroQPnXaEQAC5iWY,16980
131
+ tpu_inference/experimental/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
132
+ tpu_inference/experimental/llama3_jax_stashed.py,sha256=39XTuG-0C5pZe1oDznm6iCrvccZ_2CnC488YsvhxIho,11488
133
+ tpu_inference/kernels/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
134
+ tpu_inference/kernels/collectives/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
135
+ tpu_inference/kernels/collectives/all_gather_matmul.py,sha256=TtQWY0lNj8699JwDmjqbRrdku-3oAw5WkuuoFPS49AY,27597
136
+ tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py,sha256=OEPf4q08IeIFyJfzizgRs6kSD7w35NeZDRIn7CcZ344,1468
137
+ tpu_inference/kernels/collectives/util.py,sha256=LbLD6lOxuszbUsykF89gWQqEJUICCZsfzam3EJDPnFE,1859
138
+ tpu_inference/kernels/flash_attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
139
+ tpu_inference/kernels/flash_attention/kernel.py,sha256=n8gmAFVfchMXlyaSEj8xXJm6AadFt26edQihPRdithY,25897
140
+ tpu_inference/kernels/fused_moe/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
141
+ tpu_inference/kernels/fused_moe/v1/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
142
+ tpu_inference/kernels/fused_moe/v1/kernel.py,sha256=B0qWaa5vphIa3MJmeTbvpBMh9JJlRWNpmoORrz79Cvk,64990
143
+ tpu_inference/kernels/megablox/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
144
+ tpu_inference/kernels/megablox/common.py,sha256=CoJPNom6anJU9B4i05d2skytJEvNS994DYo0eEyVGuY,1639
145
+ tpu_inference/kernels/megablox/gmm.py,sha256=rVW70SGPshR9XvHiwzmskX4_yeD4nE8or3RfabwcCLM,24240
146
+ tpu_inference/kernels/mla/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
147
+ tpu_inference/kernels/mla/v1/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
148
+ tpu_inference/kernels/mla/v1/kernel.py,sha256=oovjb0x3qz08IL_KVjLLbNbcEcFXip55fqgIgfnl3RA,49758
149
+ tpu_inference/kernels/quantized_matmul/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
150
+ tpu_inference/kernels/quantized_matmul/kernel.py,sha256=-A9Kd2ApHWgPvCaUPfjM5JooLz_iCfWV1UT0taaZaAo,16264
151
+ tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py,sha256=3zhIm73JEE8qOty2_0v3AJlVz13k6qMB5wlXBDyC1EM,35130
152
+ tpu_inference/kernels/quantized_matmul/util.py,sha256=rf6nIiAj9I2cj4LDvtaZGhcLXEc94o2xgMWasnFaREM,1943
153
+ tpu_inference/kernels/ragged_paged_attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
154
+ tpu_inference/kernels/ragged_paged_attention/v2/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
155
+ tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=462jgsWdnaQfO9K1Y99cJ-qidYWXZMc5GdoY9enQEWY,35019
156
+ tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=y9-C_F28WGd282Ra_DqwTbHyUIIj2jyWY3DiX8yozHY,11080
157
+ tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
158
+ tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
159
+ tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=HVTQ4LJiEkWiYuUV1ey-2K2u6IULjJQ2dbX3qpo3FLA,60593
160
+ tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=VutC0CwPfF-luuRSPv6b7QiFt2EBiCPdoTMtOrFFZtI,60391
161
+ tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=sG67fBe8ckXdfvO7c9gfGFhu6_8owir8ZE6IOyHhNFY,231477
162
+ tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=WusgnI6oDRsUoF8lp4vsaPepKO8oTJLlPSlLDpr3-7Y,25025
163
+ tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=VVYHEHmANvEddEKx8IPTRSXDykwzEOJa2GZKNv7nwnM,1755
164
+ tpu_inference/layers/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
165
+ tpu_inference/layers/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
166
+ tpu_inference/layers/common/attention_interface.py,sha256=WNEXNj1_6mDNS4KDXJRu9hkbJmKFlsp78txbqbDWhTo,13712
167
+ tpu_inference/layers/common/attention_metadata.py,sha256=rmipY517sefHe4owxC5USkm4lbL4zd4LZKokDYGECQo,1425
168
+ tpu_inference/layers/common/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
169
+ tpu_inference/layers/common/fused_moe_gmm.py,sha256=xzrFK1fRZXsF_a1robY1qe5I9rQ3t2kcjhN4KHmt75Q,19862
170
+ tpu_inference/layers/common/quant_methods.py,sha256=SCm9g7bE02XSMONmOCuT0vfHeTP6RzGQ57aTj919HgM,772
171
+ tpu_inference/layers/common/quantization.py,sha256=63-kb4XR3D1mCryBYhRy881W2X52m7kF_CmHeETo2R8,9216
172
+ tpu_inference/layers/common/sharding.py,sha256=curCejZPj8ND4rxjWEbwRozkFYlK_HlpIyTywhDHcWU,26171
173
+ tpu_inference/layers/common/utils.py,sha256=k1OWrJJI6E58TCNUXO7TFc5l_9XmwL3d7N2U4QE-zPs,4417
174
+ tpu_inference/layers/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
175
+ tpu_inference/layers/jax/base.py,sha256=UhT4ut_59ynUPdaZGpMPSCQkPTWXA9BxkaPy7lDhoLI,6350
176
+ tpu_inference/layers/jax/constants.py,sha256=YQJOeAbja1yTbPhoOWMp24OF1RCMwPybK1NIwPrrYJ0,3329
177
+ tpu_inference/layers/jax/layers.py,sha256=elv04eCMFj5Jt3SF0PXxyuQPTwmJDgsuvZ9oK88HTso,11208
178
+ tpu_inference/layers/jax/misc.py,sha256=Jdxv8SAT1yVuM_1_lGWImRSXlu2xGLnXI-TRGRNsBYw,1141
179
+ tpu_inference/layers/jax/pp_utils.py,sha256=gP3Xt-Pinm6E7yJ9jtsSnmmoz9GmgBN83TkSgIrz0OA,1726
180
+ tpu_inference/layers/jax/rope.py,sha256=FbZKJPd9T0IDaZyOJkrFl2CL1on1womCzZBiUPLU0O4,11924
181
+ tpu_inference/layers/jax/rope_interface.py,sha256=cPqVpKG5_SU7S7xcrMEaPBJLqi1nC4uMN-1S-dmb0mQ,8950
182
+ tpu_inference/layers/jax/transformer_block.py,sha256=HTI0fYPQd23UbnJSB_pL2K3un3q_i3guvJiNCUReVRs,4492
183
+ tpu_inference/layers/jax/attention/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
184
+ tpu_inference/layers/jax/attention/attention.py,sha256=_N5W4ox8EzC1CZYcIhsEi35X8WCIMFEBlSzVtDDcTu8,10623
185
+ tpu_inference/layers/jax/attention/deepseek_v3_attention.py,sha256=KP-hgck-wTzTcwDNB08DwNiqsE-6OD4tQ1jLVwWQvEw,22427
186
+ tpu_inference/layers/jax/attention/gpt_oss_attention.py,sha256=EM1kJpr77VHh95aSD5UnSJazB_anS_7PyaD8TixVMrY,9241
187
+ tpu_inference/layers/jax/attention/llama4_attention.py,sha256=QzBDoEioI9mMdI1T2LNlsr89iaGl234e-9s202YWS8M,6713
188
+ tpu_inference/layers/jax/moe/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
189
+ tpu_inference/layers/jax/moe/deepseek_v3_moe.py,sha256=5j6TJO8fAB2Yv6mVAeM2F9WLe4QDM9bf6zxtdKjHjCQ,26456
190
+ tpu_inference/layers/jax/moe/gpt_oss_moe.py,sha256=-uliFqHJFOTT9WJCEpGhkImOXMSoo3aePXMOmKXlgmk,6771
191
+ tpu_inference/layers/jax/moe/moe.py,sha256=E7L8bJucTVke89o048GAbWdtuQIL5oDz-MkW0NK4E00,10114
192
+ tpu_inference/layers/jax/sample/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
193
+ tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=VqN0mxi7Xg58w4EXS625ndC8NyA_UZMV9bjFM1mkvrY,21000
194
+ tpu_inference/layers/jax/sample/sampling.py,sha256=IfJBFSXuTdd0QELn8Opmh7HgdzKreIwGYUOskTFp4aI,3888
195
+ tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=bip7TQcw-VHyN6072zBQY-tA0-QTyJpnuYg04mw9Sv0,3136
196
+ tpu_inference/layers/vllm/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
197
+ tpu_inference/layers/vllm/attention.py,sha256=LMQbS2KAup0Q-mmN5pzV6uUs-qdGpTSH8eV6ByHde9g,7370
198
+ tpu_inference/layers/vllm/fused_moe.py,sha256=E4JeuCekVYsvMLJkccOrP690GL2Q_EWlLwW3ZK5NT-0,4013
199
+ tpu_inference/layers/vllm/linear.py,sha256=KRScVrEGys3NLpDzG0UieHb371UJR1R_ct6LR84_-iE,2428
200
+ tpu_inference/layers/vllm/process_weights/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
201
+ tpu_inference/layers/vllm/process_weights/cleanup_sharding.py,sha256=vg9PdWdY8caYKBs5G_YhkKA4MdAR213knoLv_TJlyDU,9793
202
+ tpu_inference/layers/vllm/process_weights/fused_moe_weights.py,sha256=vVrzVzrJ6_vUMPI_Nzqmqco2yeZb9O3CEzNII2rXWU0,14936
203
+ tpu_inference/layers/vllm/process_weights/linear_weights.py,sha256=3Qx-Dgdx5Khjb9B0LXmFVUz7Tc8bXf6esSfk7MWicwM,6068
204
+ tpu_inference/layers/vllm/quantization/__init__.py,sha256=r9oDaXh0TiDSnh2WOWEYfPDRaH3aU9uW2ANHrezZZjw,2450
205
+ tpu_inference/layers/vllm/quantization/awq.py,sha256=5HdRtJ1E5adCKmDIlPkIzXdgdBsSakrmRPKnQjryEwk,8595
206
+ tpu_inference/layers/vllm/quantization/configs.py,sha256=0q-gRrR7sxgUty1OzmIc6MrMH9dpuN_DYHISskvlpk8,4925
207
+ tpu_inference/layers/vllm/quantization/fp8.py,sha256=z4xXpqy7I37p6rBZjlCQRomFQzbWHOw1xWkHN3_bndw,4541
208
+ tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=q7EnVQlbdTy_qicmRo_mn6t5Q3fEt_cs31SUUVga8hU,8597
209
+ tpu_inference/layers/vllm/quantization/unquantized.py,sha256=YFZHAjmrjWnuZuwx-lG0Eka9BNqCvIq5kNbEY6vAn3Y,10795
210
+ tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
211
+ tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=VuEqI7HpN39Xee-z5ohuqlu9PdlcBpFJpfe79PsJhx0,5930
212
+ tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=8r1dT0UexEQD9-4kGiky1x7ITVpMPU90bzs-6HZQ51E,7841
213
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
214
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=W84yM33UCkCF_AZRNCoPGLqFI_EO2WHLcCfzx5TWzl4,9529
215
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=c5lqLSg-6u6Y56XYH9m1-20hlmNQ_zIB832NXDLJWJ4,6816
216
+ tpu_inference/lora/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
217
+ tpu_inference/lora/torch_lora_ops.py,sha256=YR3Hj8nLLiQ-6wXy4uFsjQxFTbJYZ4o5dh_L0mlXg-o,3261
218
+ tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
219
+ tpu_inference/models/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
220
+ tpu_inference/models/common/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
221
+ tpu_inference/models/common/model_loader.py,sha256=gSaY_PCRtVjx-lKsNROGmgR41E_oMba2dVxtQONADvI,21878
222
+ tpu_inference/models/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
223
+ tpu_inference/models/jax/deepseek_v3.py,sha256=mje3RgxE1NwKWVLgJnPq3ebWB1J8T6YGHT2TtxN10Dg,45031
224
+ tpu_inference/models/jax/gpt_oss.py,sha256=bgdsCx3UcTqEJatWBYbma5HNHH8GEaHN4aL5IsAeSmM,21592
225
+ tpu_inference/models/jax/jax_intermediate_tensor.py,sha256=XKpDgPkOiRtYaPrW76ILxcp2uFfSiE1JMdqHWGo0-Ss,3179
226
+ tpu_inference/models/jax/llama3.py,sha256=FjTGC69V_EJmvb5BIqYu3V5NS1Pvy-5Pb34kMn5YU5U,16317
227
+ tpu_inference/models/jax/llama4.py,sha256=Ssycb5fcGjhJYg8FfcNckVhow7bvVt0FJbbpHinzMAA,30206
228
+ tpu_inference/models/jax/llama_eagle3.py,sha256=_wnljvb8lLCQ0Z3Vuw0QI7F6b41x6I1WuvstZWGvCYE,13051
229
+ tpu_inference/models/jax/llama_guard_4.py,sha256=R4wo45s1JsVD39t8JeAItujGoi-sl43HBH95hr7qEVw,15845
230
+ tpu_inference/models/jax/qwen2.py,sha256=bart2yYGv0J-lNbk8Hk5jn5IF6j_Jp8YKSEjwVU_y24,14038
231
+ tpu_inference/models/jax/qwen2_5_vl.py,sha256=3g3tUt7c83fKOdiMzuq2VyldCyeXoCBGrVYfqyIWwGE,50370
232
+ tpu_inference/models/jax/qwen3.py,sha256=jVOOVrBFnxRIZ_Euo90iCga8rORpz0Kqs79uKqsFwEQ,11678
233
+ tpu_inference/models/jax/utils/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
234
+ tpu_inference/models/jax/utils/file_utils.py,sha256=8iZcGNvF1N0gNioH8fBlVYTSGYn4fC2WvmlTyeDZyZM,3415
235
+ tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=c2LRXdOPi3F779yg2UX-DnuFDxF1JciTcFa09iODxZs,6695
236
+ tpu_inference/models/jax/utils/weight_utils.py,sha256=0xyjGlDSrA09gtb4plw9yX57VPMgn3o5WNl6mXPDU70,23121
237
+ tpu_inference/models/jax/utils/qwix/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
238
+ tpu_inference/models/jax/utils/qwix/qwix_utils.py,sha256=w3wmDb1drJxOK1mVRVMORznqKbtZqFfi7H0Ib_k-iW8,29526
239
+ tpu_inference/models/vllm/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
240
+ tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=mqD0qSnRY28CJ-ZU9BLXPD4zcMui0_P2vBZsCn2KWTs,13053
241
+ tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=vsXQnC2aZ_mHKb-7d9UeN28lfawfApNTm5asUMgEhgo,1762
242
+ tpu_inference/platforms/__init__.py,sha256=BK6rwAhiqVSAUJ9m9EehSKetA6hEPe92flD9Ei076WQ,649
243
+ tpu_inference/platforms/tpu_platform.py,sha256=loDc6hi9DlBmcoN6CjuEt6GKYL7tXY29D086s00_M4o,9474
244
+ tpu_inference/runner/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
245
+ tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
246
+ tpu_inference/runner/compilation_manager.py,sha256=BFjOzJUyEJTmUZAvGCm3yeqoY7Kkw2JKc_A3CzRoN7o,42112
247
+ tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
248
+ tpu_inference/runner/kv_cache.py,sha256=xpB6VTrT3lIq5JNNPJTVEnHFgehIzgxKNIHxxXIxwKI,6046
249
+ tpu_inference/runner/kv_cache_manager.py,sha256=u6pXaWPzmPe34lXiy-acAdGBmp9WEQrGvksyBfGBRdM,23342
250
+ tpu_inference/runner/lora_utils.py,sha256=LgnrePvkBFyMvQqSp9VfrIbWPBwpWG4_iUaj3lX0Os8,4448
251
+ tpu_inference/runner/multimodal_manager.py,sha256=sNzj_U4XTRQtuslKljxbcS6NRNlFB_bN6l0qpnqrlfM,10315
252
+ tpu_inference/runner/persistent_batch_manager.py,sha256=aCeTyqCgBnQy_6hXjiNLtF81ekG0-YwlQiWeJhx-pdM,13838
253
+ tpu_inference/runner/speculative_decoding_manager.py,sha256=-eSxTIGXbRWRZjHJfikb7kfqbtr_cj7Pca9zInWSn1w,10790
254
+ tpu_inference/runner/structured_decoding_manager.py,sha256=sj1fPrit0qdhcQtDbue5kpxos7zL16_dZQ5YSXTDbzg,4148
255
+ tpu_inference/runner/tpu_runner.py,sha256=cgIyZiI3UjpvPWhNRL-mCSnssbbDNt00g5idAzwgWR0,80736
256
+ tpu_inference/runner/utils.py,sha256=lKqL5nxGTk7ufzJRNdp4udn2bPu3jIX52W7akXgSrHc,17133
257
+ tpu_inference/spec_decode/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
258
+ tpu_inference/spec_decode/jax/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
259
+ tpu_inference/spec_decode/jax/eagle3.py,sha256=5WtEbkgzXpmFz374ibQD5IIcRro4d0SNeCYgBv2nM1c,19678
260
+ tpu_inference/worker/__init__.py,sha256=Q9FlRO2IfSE9yEaiAYzWkOMBJPCaNYqh4ihcp0t0BQs,574
261
+ tpu_inference/worker/tpu_worker.py,sha256=ntwCibPyiw-z8aMUdtu8usqU_q2b0u7diWNOmpjG_6o,21651
262
+ tpu_inference-0.13.2.dev20251230.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
263
+ tpu_inference-0.13.2.dev20251230.dist-info/METADATA,sha256=08-onD7oUGsgmWyILrp51XmacHdKXu1X824ws4eoh88,5767
264
+ tpu_inference-0.13.2.dev20251230.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
265
+ tpu_inference-0.13.2.dev20251230.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
266
+ tpu_inference-0.13.2.dev20251230.dist-info/RECORD,,
@@ -1,208 +0,0 @@
1
- from typing import Optional, Union
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import torch
6
- from jax.experimental.shard_map import shard_map
7
- from jax.sharding import Mesh, NamedSharding
8
- from jax.sharding import PartitionSpec as P
9
- from torchax.interop import torch_view
10
- from torchax.ops.mappings import t2j
11
-
12
- from tpu_inference import envs
13
- from tpu_inference.kernels.quantized_matmul.kernel import (
14
- quantized_matmul_kernel, xla_quantized_matmul)
15
-
16
-
17
- def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
18
- mesh: Mesh, weight_sharding: P) -> jax.Array:
19
- """
20
- Wrapper around the quantized matmul kernel.
21
-
22
- Args:
23
- x: Activation.
24
- w_q: Weight quantized array. [n_output_features, n_input_features]
25
- w_s: Weight quantization scale. [n_output_features]
26
- mesh: Mesh to shard on.
27
- weight_sharding: PartitionSpec for the weight tensor.
28
-
29
- Returns:
30
- Output of the quantized matmul.
31
- """
32
-
33
- # NOTE (jacobplatin/kyuyeunk) there have been numeric issues (concerning) NaNs
34
- # with the kernel and thus we disable it for now.
35
- if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
36
- out_axis, in_axis = weight_sharding
37
- x_sharding = P(None, in_axis)
38
- scale_sharding = P(out_axis, )
39
- out_sharding = P(None, out_axis)
40
-
41
- x = jax.lax.with_sharding_constraint(x,
42
- NamedSharding(mesh, x_sharding))
43
-
44
- def wrapper(x, w_q, w_s):
45
- output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
46
- if in_axis:
47
- output = jax.lax.psum(output, axis_name=in_axis)
48
- return output
49
-
50
- return shard_map(wrapper,
51
- mesh=mesh,
52
- in_specs=(x_sharding, weight_sharding,
53
- scale_sharding),
54
- out_specs=(out_sharding),
55
- check_rep=False)(x, w_q, w_s)
56
- else:
57
- return xla_quantized_matmul(x, w_q, w_s)
58
-
59
-
60
- def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
61
- split_sizes: list[int],
62
- n_shards: int, dim: int):
63
- """
64
- Reorder a replicated concatenated tensor such that when sharded on multiple chips, each shard is a concatenation of the shards of the individual tensors.
65
- For example, let the concatenated_tensor be:
66
- AAAAAAAAAAAABBBBBBBBCCCC
67
- 12 As 8 Bs 4 Cs
68
- and let the split_sizes = [12, 8, 4] and n_shards = 4.
69
- The output is:
70
- AAABBCAAABBCAAABBCAAABBC
71
- In other words, it reorders the input tensor into 4 segements, with each segment corresponding to a shard and being AAABBC.
72
- Args:
73
- concatenated_tensor: the tensor, concatenated on the dimension specified by `dim`.
74
- split_sizes: each individual tensor's size on the dimension specified by `dim`.
75
- n_shards: num of shards.
76
- dim: the dimension on which the concatenated_tensor is concatenated.
77
- """
78
- # Split the concatenated tensor into individual tensors.
79
- split_tensors = []
80
- start_offset = 0
81
- old_shape = concatenated_tensor.shape
82
- # New shape ensures each split_tensor[i] maps to a tensor in ith shards
83
- new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
84
- for split_size in split_sizes:
85
- split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
86
- start_offset,
87
- start_offset + split_size,
88
- axis=dim)
89
- split_tensors.append(split_tensor.reshape(new_shape))
90
- start_offset += split_size
91
- # While maintaining 0th dim as a shard dim, we concatenate along 1th dim to
92
- # to create concatenated tnensor where 0th dim maps to shard dim.
93
- reordered_tensor = jnp.concatenate(split_tensors, axis=dim + 1)
94
- return reordered_tensor.reshape(old_shape)
95
-
96
-
97
- def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
98
- split_sizes: list[int],
99
- n_shards: int):
100
- """
101
- Slice the input tensor which is sharded on multiple chips (on the last dim) into individual tensors with the same sharding.
102
- For example, let the sharded_tensor be:
103
- AAABBC | AAABBC | AAABBC | AAABBC
104
- Shard0 Shard1 Shard2 Shard3
105
- and let the split_sizes = [12, 8, 4] and n_shards = 4.
106
- The output is a list of 3 tensors:
107
- AAA | AAA | AAA | AAA
108
- BB | BB | BB | BB
109
- C | C | C | C
110
- Shard0 Shard1 Shard2 Shard3
111
- In other words, each individual tensor is a slice of the input tensor with the same sharding.
112
- Args:
113
- sharded_tensor: the input tensor, sharded on the last dim.
114
- split_sizes: each individual tensor's size on the last dim.
115
- n_shards: num of shards.
116
- """
117
- new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
118
- # New shape ensures each sharded_tensor[:, i] maps to a tensor in ith shards
119
- sharded_tensor = sharded_tensor.reshape(new_shape)
120
-
121
- split_tensors = []
122
- start_offset = 0
123
- for split_size in split_sizes:
124
- assert split_size % n_shards == 0
125
- sz = split_size // n_shards # size of this split tensor per shard
126
- end_offset = start_offset + sz
127
- # Because we are slicing over last dim, sharding dim remains intact.
128
- # Therefore, splitting happens locally.
129
- split_tensor = sharded_tensor[..., start_offset:end_offset]
130
- split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
131
- start_offset = end_offset
132
-
133
- return split_tensors
134
-
135
-
136
- def torch_to_jax_param(
137
- tensor: torch.Tensor,
138
- sharding: NamedSharding,
139
- output_sizes: Optional[int],
140
- n_shards: int,
141
- fused: bool,
142
- dim: int = 0,
143
- jax_dtype: Optional[jnp.dtype] = None,
144
- ) -> Union[torch.nn.Parameter, torch.nn.ParameterList]:
145
- if output_sizes is None:
146
- output_sizes = [tensor.shape[0]]
147
-
148
- tensor = t2j(tensor, use_dlpack=False)
149
- if jax_dtype:
150
- tensor = tensor.astype(jax_dtype)
151
-
152
- if fused:
153
- tensor = reorder_concatenated_tensor_for_sharding(
154
- tensor, output_sizes, n_shards, dim)
155
- tensor = jax.device_put(tensor, sharding)
156
- param = torch.nn.Parameter(torch_view(tensor), requires_grad=False)
157
- else:
158
- tensors = []
159
- start_offset = 0
160
- for size in output_sizes:
161
- end_offset = start_offset + size
162
-
163
- tensor_split = jax.lax.slice_in_dim(tensor,
164
- start_offset,
165
- end_offset,
166
- axis=dim)
167
- tensor_split = jax.device_put(tensor_split, sharding)
168
- tensor_split = torch.nn.Parameter(torch_view(tensor_split),
169
- requires_grad=False)
170
- tensors.append(tensor_split)
171
-
172
- start_offset = end_offset
173
- param = torch.nn.ParameterList(tensors)
174
- return param
175
-
176
-
177
- MODEL_MATMUL_FUSION_TRUTH_TABLE = {
178
- ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
179
- True,
180
- ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
181
- False,
182
- ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
183
- False,
184
- ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
185
- False,
186
- ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
187
- False,
188
- ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
189
- False,
190
- ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
191
- False,
192
- ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
193
- False,
194
- ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
195
- False,
196
- ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
197
- False,
198
- ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
199
- False,
200
- ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
201
- False,
202
- }
203
-
204
-
205
- def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
206
- tp_size: int, layer_name: str):
207
- key = (model_name, batch_size, tp_size, layer_name)
208
- return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
File without changes
@@ -1,5 +0,0 @@
1
- qwix:
2
- rules:
3
- # NOTE: each entry corresponds to a qwix.QuantizationRule
4
- - module_path: '.*'
5
- weight_qtype: 'float8_e4m3fn'
@@ -1,6 +0,0 @@
1
- qwix:
2
- rules:
3
- # NOTE: each entry corresponds to a qwix.QuantizationRule
4
- - module_path: '.*'
5
- weight_qtype: 'float8_e4m3fn'
6
- act_qtype: 'float8_e4m3fn'
@@ -1,5 +0,0 @@
1
- qwix:
2
- rules:
3
- # NOTE: each entry corresponds to a qwix.QuantizationRule
4
- - module_path: '.*'
5
- weight_qtype: 'int8'
@@ -1,6 +0,0 @@
1
- qwix:
2
- rules:
3
- # NOTE: each entry corresponds to a qwix.QuantizationRule
4
- - module_path: '.*'
5
- weight_qtype: 'int8'
6
- act_qtype: 'int8'
@@ -1,105 +0,0 @@
1
- # SPDX-License-Identifier: Apache-2.0
2
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
-
4
- import torch
5
-
6
- # MXFP4 constants
7
- MXFP4_BLOCK_SIZE: int = 32
8
- # Exponent-only e8m0 scale bias used by MXFP4 scales
9
- MXFP4_SCALE_BIAS: int = 127
10
- # Name used in config.json quantization_config["quant_method"]
11
- MXFP4_QUANT_METHOD: str = "mxfp4"
12
-
13
- # Precompute a small LUT once; move to device on demand (cheap 16-element copy)
14
- FP4_LUT = torch.tensor(
15
- [
16
- 0.0,
17
- 0.5,
18
- 1.0,
19
- 1.5,
20
- 2.0,
21
- 3.0,
22
- 4.0,
23
- 6.0, # 0b0000-0b0111
24
- -0.0,
25
- -0.5,
26
- -1.0,
27
- -1.5,
28
- -2.0,
29
- -3.0,
30
- -4.0,
31
- -6.0, # 0b1000-0b1111
32
- ],
33
- dtype=torch.float32)
34
-
35
-
36
- def unpack_mxfp4(packed: torch.Tensor) -> torch.Tensor:
37
- """Unpack uint8 (..., 16) -> fp4 values (..., 32) using low->high nibble order.
38
-
39
- Returns float32 values corresponding to FP4 codebook entries.
40
- """
41
- assert packed.dtype == torch.uint8
42
- low = packed & 0x0F
43
- high = (packed >> 4) & 0x0F
44
- idx = torch.stack([low, high], dim=-1).flatten(-2)
45
- lut = FP4_LUT.to(packed.device)
46
- return lut[idx.long()]
47
-
48
-
49
- def e8m0_to_fp32(u8: torch.Tensor) -> torch.Tensor:
50
- """Convert e8m0 uint8 exponents to power-of-two scales using MXFP4_SCALE_BIAS.
51
-
52
- Uses ldexp for exact power-of-two scaling: 1.0 * 2**(u8 - bias).
53
- """
54
- exponents = (u8.to(torch.int32) - int(MXFP4_SCALE_BIAS)).to(torch.int32)
55
- ones = torch.ones_like(u8, dtype=torch.float32)
56
- return torch.ldexp(ones, exponents)
57
-
58
-
59
- def dequant_mxfp4_to_bf16(blocks_u8: torch.Tensor,
60
- scales_u8: torch.Tensor) -> torch.Tensor:
61
- """Dequantize MXFP4 blocks/scales into bfloat16 values.
62
-
63
- Args:
64
- blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
65
- scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.
66
-
67
- Returns:
68
- torch.bfloat16 tensor with last logical dimension K = Kb * 32.
69
- """
70
- if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
71
- raise ValueError(
72
- f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
73
- )
74
- # Unpack FP4 codes to float32 values [..., Kb, 32]
75
- fp4_vals = unpack_mxfp4(blocks_u8) # (..., Kb, 32)
76
- # Compute power-of-two scales and apply per block
77
- scales = e8m0_to_fp32(scales_u8).unsqueeze(-1) # (..., Kb, 1)
78
- full = (fp4_vals * scales).reshape(*fp4_vals.shape[:-2],
79
- fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
80
- return full.to(torch.bfloat16)
81
-
82
-
83
- def unpack_mxfp4_to_fp32(
84
- blocks_u8: torch.Tensor,
85
- scales_u8: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
86
- """Decode MXFP4 packed blocks and e8m0 scales to float32 codes and scales.
87
-
88
- Args:
89
- blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte packs two FP4 codes.
90
- scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per block.
91
-
92
- Returns:
93
- (codes_fp32, scales_fp32), where
94
- - codes_fp32 has shape [..., Kb*32] and dtype float32
95
- - scales_fp32 has shape [..., Kb] and dtype float32
96
- """
97
- if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
98
- raise ValueError(
99
- f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
100
- )
101
- fp4_vals = unpack_mxfp4(blocks_u8) # (..., Kb, 32) float32
102
- codes_fp32 = fp4_vals.reshape(*fp4_vals.shape[:-2],
103
- fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
104
- scales_fp32 = e8m0_to_fp32(scales_u8) # (..., Kb) float32
105
- return codes_fp32, scales_fp32