tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (56) hide show
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  tests/test_base.py,sha256=Ct5WFRMHL7IHEIxk8FrzAvO8m0xFuDpzDBKkAKKAL2Q,7341
3
- tests/test_envs.py,sha256=h502VxL2gvhECm8u5uDh5JTGvhFf_DfQO88SpqOFMzE,7135
3
+ tests/test_envs.py,sha256=Woyfp_d5HS-uTGo4_u9dYlBbgmhfIEoFb-Rx_k7YXD4,6298
4
4
  tests/test_quantization.py,sha256=IT5ASyS1uuWcxc22kRtBcA-V4j3Z3hb7pMztm3GOlBs,34445
5
5
  tests/test_tpu_info.py,sha256=ZrwlMsp8ffITkS_b8Q1t_QG-a-WVAd4NUcjHhGibcsI,4670
6
- tests/test_utils.py,sha256=GIXLdd-x4gnqSLrySXGk22phqPc8MegFd7ph1Jj8OcU,8182
6
+ tests/test_utils.py,sha256=Mta5ZzYCgRAh1-BjcOvvx9iQ9DnnXLps7oDHxVQp2yE,8236
7
7
  tests/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  tests/core/test_core_tpu.py,sha256=r496rk1eOsK_F4nvm9zprl_T-RcO6eCUb7LuVReOZno,21413
9
9
  tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
@@ -11,37 +11,37 @@ tests/core/test_disagg_utils.py,sha256=alktTGppaGdg-_un0Amz8Y0IDQz-xNJN0dXG-YApE
11
11
  tests/core/test_dp_scheduler.py,sha256=IwCR1Vs96V4CQdWA051rNaYxxr2V_byA1yx9HWyRoMg,37339
12
12
  tests/core/test_init.py,sha256=NEFI5A9eKGu4rmeJ2iqd0EmhlA3bzbVkXmMi1PV1b9U,1687
13
13
  tests/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- tests/kernels/fused_moe_v1_test.py,sha256=sQ6gvpti94fpPYrSZn7frPPNjqbVmRibFtenVrGGCA4,10403
14
+ tests/kernels/fused_moe_v1_test.py,sha256=c6zbSHQDzOseeyL9VCjQeP7zayNnwYf059CPlKcvZzQ,3137
15
15
  tests/kernels/mla_v1_test.py,sha256=oZc4TCgquiG0KOeWfv46yJbUIpro_CgCMFc7vzyB7t8,11646
16
16
  tests/kernels/quantized_matmul_kernel_test.py,sha256=od5-zXFjcsc_gWGRDrREL8E_ftymNniQVTzgtkBo_Gc,5679
17
17
  tests/kernels/ragged_kv_cache_update_v2_test.py,sha256=6-HjP5CoUG-kcuP8MS-JJVMiBnPRo_zadS3VInnO0D4,10821
18
18
  tests/kernels/ragged_paged_attention_kernel_v2_test.py,sha256=pWqo9UYF0tzwgBKO_xYw-TYSPrtAsKcMK5Haj8hFG7I,11340
19
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py,sha256=vLfe1I_vLdf0SqtBuBL7QHLSklrhWOOzYF-I_I3rdNo,16309
19
+ tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py,sha256=JhIElqUZIRqIsfQ3U1RUzSiH_gz_SabAqDosGGZ2tlA,16321
20
20
  tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=Hrd8iUkS1pS3rxeTyY53aYRg_ZL_d3NqgBXvOgnigSU,14838
21
21
  tests/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  tests/lora/conftest.py,sha256=EXjwE1CjmUUlMEXpyE3UwxvgrKUllE73I8BNKfP1FTc,984
23
23
  tests/lora/test_bgmv.py,sha256=gQxWsJdNX2nkrE2xyrG0exwf3E2eHm2k2nkEXoANuQc,1359
24
- tests/lora/test_layers.py,sha256=6B4HhMAItQmt0hPAQgyXgwSYs7b3bIbUf6LaPsqXLzY,25923
24
+ tests/lora/test_layers.py,sha256=21ekYlsK36r1GPZOfzs7E-KIsfI1JcuZl1E6vaQbHf4,26273
25
25
  tests/lora/test_lora.py,sha256=wJiF1P1BDnPN8TLX2tlFtdZ_QCkV-S9nPl6_uR6DqFc,4439
26
- tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
27
- tpu_inference/__init__.py,sha256=p4MaepRdN7723FUNE-3pOMxZWjFn4_TVFgjrNyty4JE,2304
26
+ tests/lora/utils.py,sha256=dR_v1H20vPVjFHdBhDajWOz0WJZlKuPLgMFQsME0LtA,3009
27
+ tpu_inference/__init__.py,sha256=7IduGWw-_fwx0VA6EvC_AqHF67fnnShz6YvkqCfvFx8,1317
28
28
  tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
29
- tpu_inference/envs.py,sha256=ugze6VdQ_hG1IxUCbcgXZq7a22fZ-Lora3V_fkFOefw,5714
29
+ tpu_inference/envs.py,sha256=MTT_Pdtd6cAcciYjv1OekEmvspaq3SYL0oR_jDkQ_aE,3948
30
30
  tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
31
- tpu_inference/tpu_info.py,sha256=3iilHRQSFjwMJwhKcuuawTm7mhwkgHbj4zi6CiAySrs,2265
32
- tpu_inference/utils.py,sha256=mHbjI8fxInPxagLsSUg-R3DzSz-X7WYNdoorPYoE3hg,10855
31
+ tpu_inference/tpu_info.py,sha256=9UohshkndR6dZpGWpWXfTD4qvIVdVgHf0yOoSEkLTrw,2276
32
+ tpu_inference/utils.py,sha256=iGPY147jP_8AKMu3g7vYTndjJJiOrK_4opA0JWtws5Q,10068
33
33
  tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
35
35
  tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
36
- tpu_inference/core/disagg_utils.py,sha256=lv8MAVoAjtcmTaenUXVokg2q3d0tzsma86UiQlQ3omY,1492
36
+ tpu_inference/core/disagg_utils.py,sha256=ufWNFWQ5n4YnZpPOtoReHlYo4dlN7AbIqCyqS4an0t4,1572
37
37
  tpu_inference/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
38
  tpu_inference/core/sched/dp_scheduler.py,sha256=mKs8Ms46szdlBfo8hjdqis2ZKAZbcKnHAGfEr0X5R8g,22527
39
39
  tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
40
  tpu_inference/distributed/jax_parallel_state.py,sha256=5_xCwcL03lFPUoSO_OP7hIVKpUFroW1m-jVO7R6FbUc,2223
41
- tpu_inference/distributed/tpu_connector.py,sha256=kLaTwy6BrAThJeFkd1soJ47bBo5iGp4GjUJs7xFx4Tg,29696
42
- tpu_inference/distributed/utils.py,sha256=1KIREn28Zg10O-MSUkVQMRzS09WoGc_VLGOX4QTFJac,1504
41
+ tpu_inference/distributed/tpu_connector.py,sha256=Zah46Sm5iOuh72SzXw69NxMc0MLnqsLEpe2BfDhpnqA,29731
42
+ tpu_inference/distributed/utils.py,sha256=RwFQi8G4TzN1g9RjQu0pb5JxSc_jhoIZVsFJo0uHjxo,1513
43
43
  tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
- tpu_inference/executors/ray_distributed_executor.py,sha256=9CnzWb8aurH1B0tJfMHB73F-RQBGqSf5DnymetBvZ5o,16225
44
+ tpu_inference/executors/ray_distributed_executor.py,sha256=ZMuVUwmroi7UUZs3u67OsOwUIkxNDz9IszUPG20F18E,15904
45
45
  tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
47
47
  tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -53,7 +53,7 @@ tpu_inference/kernels/flash_attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCe
53
53
  tpu_inference/kernels/flash_attention/kernel.py,sha256=n8gmAFVfchMXlyaSEj8xXJm6AadFt26edQihPRdithY,25897
54
54
  tpu_inference/kernels/fused_moe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
55
  tpu_inference/kernels/fused_moe/v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
- tpu_inference/kernels/fused_moe/v1/kernel.py,sha256=xVXfclgbw_3U7c5W1azDFkFDK5FolBzDN9IL0rIzLQs,62813
56
+ tpu_inference/kernels/fused_moe/v1/kernel.py,sha256=QHB0QEvC3x_6zhwz06JQpaOncQcNAhOSV92dD5tGVq8,40869
57
57
  tpu_inference/kernels/mla/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
58
  tpu_inference/kernels/mla/v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  tpu_inference/kernels/mla/v1/kernel.py,sha256=dw1nhpL47uQxMFOIN2kENC6aITbalT81YZLAyr1usLU,51571
@@ -67,18 +67,18 @@ tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=OiQGAHhyggbp1Pe
67
67
  tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=vGp2ZWODTbjyG9z2z0Qf_BX-wYHd5bUybnc_DtOz0nI,10995
68
68
  tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
69
69
  tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
70
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=O179Fft5KpuN5LIFx3SghWXJJUqh3Og-xqfO4Z8QXYU,57032
71
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=1ysmx7awuSZUnR7TcyUkARAvyMxNQS-9XRFMYnadZvk,61195
70
+ tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=tlP6121yfXaukx_RQroHlHcZnbKPyyum0lAcvT0B_Pk,56132
71
+ tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=pD1Pte3neoLAxE3I3-VyV_4FuqgCHeAHGzEjMVt0MMk,56004
72
72
  tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=k3LwduhZO85cJ-pSgnGN0c2Nn8eNeQq4eA94KUXJzMw,142198
73
73
  tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=P3_ivi8iUz5QMU_3pgpl4Bkbmn0q0NpDtVJX39haRQA,11208
74
74
  tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=1N_ozjKboDYLteFJndWoLXNudj2z53rGXMkELa5Z9tY,1102
75
75
  tpu_inference/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
76
76
  tpu_inference/layers/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
- tpu_inference/layers/common/attention_interface.py,sha256=SQZ-1I32Jqg7GGI-z4BVibXbaitJHyTs26X3B5nBRVo,13369
77
+ tpu_inference/layers/common/attention_interface.py,sha256=CImMS8tuWgvaRY9YbGS3pY7OBnzeJ4Jla7LRFb4Xoa4,13224
78
78
  tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
79
79
  tpu_inference/layers/common/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
80
80
  tpu_inference/layers/common/quant_methods.py,sha256=mQSxZ44-QQtm22C_8ViejnP1cP2Dv6yc2YaP6oMKJeQ,185
81
- tpu_inference/layers/common/sharding.py,sha256=sjbwkDr2fP26Ob8f5cSDeDifr3eWFZMDHU4MKr7pIgQ,25217
81
+ tpu_inference/layers/common/sharding.py,sha256=wBqdkXZSWfnnH8pkJtyW2DSqmAe_V4Vxi0iMPaXq0Z0,25185
82
82
  tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
83
  tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
84
84
  tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
@@ -102,14 +102,14 @@ tpu_inference/layers/jax/sample/sampling.py,sha256=C30KgmdOVSaagvHhbfLgVJtVQmJo8
102
102
  tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=Gd835LNWfGM0NRQBVBqEv0nPwt5q9F4AdFym0CUS1fw,2561
103
103
  tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
104
104
  tpu_inference/layers/vllm/attention.py,sha256=wbJpcgqEAuIirv5PIULbiP-ggMKjmTanbB7Dg0BVYv4,7366
105
- tpu_inference/layers/vllm/fused_moe.py,sha256=qGbQoCq-sdcZj_Q0kP6RzQk7_YvcX7FopkpLcerjNFM,17819
105
+ tpu_inference/layers/vllm/fused_moe.py,sha256=XZt2CPUz00qZzDcyfBFz6buhVzmGL1amHalHJALl9zw,18945
106
106
  tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
107
- tpu_inference/layers/vllm/sharding.py,sha256=as7CF8UKTF3ToymwRY5Pi8uzwJk0P1sHPkWB5xEx3mA,9169
107
+ tpu_inference/layers/vllm/sharding.py,sha256=WTx1tF_7R99AdyE-lL7HQJ378hAafeI-JVRsugAvwn4,9177
108
108
  tpu_inference/layers/vllm/quantization/__init__.py,sha256=SEppGayBzzQ5tsXLSy99aqilkAawQwYxnv2alCg6-ZU,1777
109
109
  tpu_inference/layers/vllm/quantization/awq.py,sha256=-8ZmjGvSKJB6_JuwSctNWt8xHWq4VSvK_AK9iahlgCo,8495
110
- tpu_inference/layers/vllm/quantization/common.py,sha256=8XD64pPa077c9HThFhLFVHlDL9YBafnYwp6rp6gR44E,4432
111
- tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=UT6gpMrH27CusdGUMqEvQpJg1CPvsvnqAe0GKfZdV6o,13596
112
- tpu_inference/layers/vllm/quantization/unquantized.py,sha256=YaZdO_XjT06U1gtsUgNVSF1BrFqc4sCGO0dgtprUtwM,14395
110
+ tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218
111
+ tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=KwGoqIiPkd6FplGuYAKi4uX5A8MPlZqq99MVPchXyi4,11561
112
+ tpu_inference/layers/vllm/quantization/unquantized.py,sha256=Q1v1ZbSIDmaoOg97Ehv6rA5CnSf6nTP40xDBMmHHeLw,15054
113
113
  tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
114
114
  tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=6idEyy3e849fZ1UeNvc9eSHYX7e6qvohrJa_d_D9MBk,5285
115
115
  tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=FM901QhyhJRC8CuMeICzCVVERvBHbhruRxYW0EQ570s,8820
@@ -118,57 +118,62 @@ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_ten
118
118
  tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
119
119
  tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
120
120
  tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
121
- tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
121
+ tpu_inference/lora/torch_punica_tpu.py,sha256=b27DpmIS_N5bhlIcryiENYNmPxp_cu40CGxjPW64d44,12706
122
+ tpu_inference/mock/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
123
+ tpu_inference/mock/vllm_config_utils.py,sha256=FlQshLjoHdgs3C66tYHYbKFUjbk9DhUwY-7HibZk0fI,878
124
+ tpu_inference/mock/vllm_envs.py,sha256=cCubeOhH2WeYZQFJt6W0y_IiQo0fzIWR1LCCE8i6kI4,50990
125
+ tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
126
+ tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
122
127
  tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
123
128
  tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
124
- tpu_inference/models/common/model_loader.py,sha256=b3aigca81gMVJt42oF2aoRohQHjBBe3oK3IPblZAaUM,19996
129
+ tpu_inference/models/common/model_loader.py,sha256=VgxM2OODb0-69dexv4aNJ4g24Nrx5sj_ra4XStkhl14,18289
125
130
  tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
126
131
  tpu_inference/models/jax/deepseek_v3.py,sha256=SKOHVEC-_2NLxBnzBzbu5tu0d6FTlAEiI1EefGaO2QE,40047
127
132
  tpu_inference/models/jax/gpt_oss.py,sha256=Vw4LRB5Kp6hbA2hjZGFS8kiEqOCjf881XH2JNtu2S1I,20924
128
133
  tpu_inference/models/jax/jax_intermediate_tensor.py,sha256=Pxu1PCV5LN5X58aYVkPiohcXZIeKVim2oqvrS_cVgw4,2604
129
- tpu_inference/models/jax/llama3.py,sha256=ZiFtrpAzXTT9vAPES9UeuJInCWGbvDWs7g0_JLdCCa4,13479
134
+ tpu_inference/models/jax/llama3.py,sha256=w99DAfipGS9HyX2ZRwqyYLxC3oa0ew5eEQ6EXlMMf18,13426
130
135
  tpu_inference/models/jax/llama4.py,sha256=wf2Sp2iYViaYD5rSfv3_ryO6gYuYM5XaOyvghaP4OCY,29631
131
- tpu_inference/models/jax/llama_eagle3.py,sha256=7-U99yvBkle-FSZ3NDDI-obWSQ2Fo2OTOi1H67H4jxY,12476
132
- tpu_inference/models/jax/llama_guard_4.py,sha256=LrnU2zBWM0s4q_5dwmR--OO0V7ttltsYhrHYlBgQVIw,15275
133
- tpu_inference/models/jax/qwen2.py,sha256=SuAp7tErk8OoIRko0Vt6QSOZP_9B9r5GTfqmVfImUIo,13410
134
- tpu_inference/models/jax/qwen2_5_vl.py,sha256=WUOmqNE6fHQ8PGU85Y8Bt6-CtCC1Uubbox_9FdpDMMo,49833
135
- tpu_inference/models/jax/qwen3.py,sha256=CIZQKjZDke_LPGsLNhRCJdDTzWueUneBPAQ1blS24IM,11050
136
+ tpu_inference/models/jax/llama_eagle3.py,sha256=STUkAK6XEA7JM3i_Lx36-t5BhkAGeW_xYiq3zYhHP1A,12297
137
+ tpu_inference/models/jax/phi3.py,sha256=TpP3Nvr1myW_Qd8xNrLP1VmXtq7BuTcWNayJitskFd0,13579
138
+ tpu_inference/models/jax/qwen2.py,sha256=P_x_Qygf-nanmF8Uufk4c-qLNxP4RAk4yuqSF8VwbxE,13357
139
+ tpu_inference/models/jax/qwen2_5_vl.py,sha256=fvMgM5GfUn5EECaMbR0z37mmbCHphAT1AvWPvGkhVn4,43942
140
+ tpu_inference/models/jax/qwen3.py,sha256=lr3TIIQKmNgWFDFxwuPsVOypqBijkqrpnNCopVg4iBo,10997
136
141
  tpu_inference/models/jax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
137
142
  tpu_inference/models/jax/utils/file_utils.py,sha256=NOuSC3YFnZpf3CZgYdghbbiNYJt42zgjlEYbOZIVct4,2840
138
143
  tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=rrIrQWidkUnGilBHKNpdYh7_2BkvnAaqanXjC81GNcg,6156
139
- tpu_inference/models/jax/utils/weight_utils.py,sha256=qFU53jPHPvIcs_EOdIH80oNojpUp7GdSY2E6NZNsjvM,21376
144
+ tpu_inference/models/jax/utils/weight_utils.py,sha256=65-H8BTbyilIBMBfvWjkkW3mf4soYASbhrJFqbFKzL4,20129
140
145
  tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
141
146
  tpu_inference/models/jax/utils/quantization/mxfp4_utils.py,sha256=boGnqJCRIOf5nedAxQ8_IUTV6Rfll10DXnRC40BeeE8,3682
142
- tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=rzAFU3OtQvg8w8ow0V15rMljAsa4SBrwOye6OI8Bty4,26530
147
+ tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=xgKoKB7AM3TYPxzVgEGLTK9ebQH2Kx8mNuO0heovkmk,26778
143
148
  tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml,sha256=d_YHPtaRJ_7PBrPijSzJGnVeoJO62tKIGqrgFqpYT1k,137
144
149
  tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml,sha256=b7SyL75HuSTj3fN9_ZLCK_CDiccL5DGq_DddGmxj_qk,170
145
150
  tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml,sha256=0Qwij71zj9k6rmrUNd8Q5df9YYfkoJ1ZkgMAHxQy81k,128
146
151
  tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml,sha256=lGec0UwwxmNPNgKPSsTsCMSXNJjhw507KMtM2NsSCMw,152
147
152
  tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
148
- tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=3EcaD_1vZuyAZBfDtm5u_qfCahQU28qR4rAUraNAFqs,12305
153
+ tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=o3oJ7Uhu-vSJEFHHifF8e0Q7dULRKJ2GRsT1qAN6PWY,12099
149
154
  tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
150
155
  tpu_inference/platforms/__init__.py,sha256=lQCrKddS_GcGpCbeogvz9zOZD1mQw5bBsiw8On46qFQ,74
151
- tpu_inference/platforms/tpu_platform.py,sha256=q_eACjDkJkmnrUrKQzfK6hyqGEf2OjWn16-JHXwWquY,10723
156
+ tpu_inference/platforms/tpu_platform.py,sha256=AYFr1Q7VUN76wcdgOe_wZuVIHgp2U8isBJ3iHrYqt0M,10530
152
157
  tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
153
158
  tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
154
- tpu_inference/runner/compilation_manager.py,sha256=dU0Yk8f0LtRTBe2q0iB3xcMSRco_WPsj2wS6zZJ8WhY,40375
159
+ tpu_inference/runner/compilation_manager.py,sha256=yIsonouB5G0-fyVtAKuyyRXaMGNFwnX8D7q6ppQYgUI,36318
155
160
  tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
156
- tpu_inference/runner/kv_cache.py,sha256=LKOZM5o8_62KDXhhYzQl2ibifgxN89ZxHvB1NT9u3MQ,4577
157
- tpu_inference/runner/kv_cache_manager.py,sha256=N0a896CE7Zrs_d4ZSSzRdqgjV1It57RBDSIpOzkRqro,22013
161
+ tpu_inference/runner/kv_cache.py,sha256=F4dzW2d53xuxkFUn0oKzwE6VklGUeVm-QM19NVfIQDU,4577
162
+ tpu_inference/runner/kv_cache_manager.py,sha256=CJxXtdWuewJqcTBMoR70_Uvwxjtc3cK2jxe1KpI9kQc,22152
158
163
  tpu_inference/runner/lora_utils.py,sha256=B4xMCgXGJ4VNdePvn89HH3tIZ-gYsQ7Vq_YCiYIATEY,3843
159
164
  tpu_inference/runner/multimodal_manager.py,sha256=azEPdHOwz8CN11MQmorGdtrCLbFaTCxdWyuEsZTzjYM,9778
160
- tpu_inference/runner/persistent_batch_manager.py,sha256=Otu67vOTf1_HKAMZgPDDHlRvvZ3YVJdz-QderH4qOII,13263
165
+ tpu_inference/runner/persistent_batch_manager.py,sha256=KERSfKy6XjMejnbtPGI3hzoYAHJLeCxmpZVYPqBCago,11156
161
166
  tpu_inference/runner/speculative_decoding_manager.py,sha256=I3FDWKh2dn6nV8LgTGfCTwMKYnxQsTPpBIrmaJngXHs,10215
162
- tpu_inference/runner/structured_decoding_manager.py,sha256=gZQKQUFxh6xYYH9eGTdbguqk8hc2WwTrIdMMuCcbymE,3573
163
- tpu_inference/runner/tpu_runner.py,sha256=NBDKfSGShHmYpudrtGfo1hnVSQTcLpZV_nPiXEo7JPQ,79439
164
- tpu_inference/runner/utils.py,sha256=lKqL5nxGTk7ufzJRNdp4udn2bPu3jIX52W7akXgSrHc,17133
167
+ tpu_inference/runner/structured_decoding_manager.py,sha256=Y0ERPhj4olFh6Y2TxP0R1_4UIJwy7nemYA-h63YIR2U,3622
168
+ tpu_inference/runner/tpu_runner.py,sha256=3SZYn0CBA4LOaTO3GdQOxKx3HKmVcNmUEeSyzSAGyFY,73320
169
+ tpu_inference/runner/utils.py,sha256=ZnWUoNo-7INeB0mdXti1jwUOdbmxyExznOs-crRTQLk,17126
165
170
  tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
166
171
  tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
167
- tpu_inference/spec_decode/jax/eagle3.py,sha256=FxP0uWeQlHlgCpt1nY3FUd4lKlegKJljHyc05jJucaQ,19104
172
+ tpu_inference/spec_decode/jax/eagle3.py,sha256=A1dt-dmBttpy-5DGcL4noEDCB0OGP8Xo6MXqgJvWIo8,16593
168
173
  tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
169
- tpu_inference/worker/tpu_worker.py,sha256=LnZcSNxdhh0NkoWXxS5bZ0bsTMduSANehy2wELAaVsY,20672
170
- tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
171
- tpu_inference-0.0.1rc1.dist-info/METADATA,sha256=Ckyu7tcPAfxr698v8vDxUI70CyEVWLVDvUMFFcgqYYQ,5503
172
- tpu_inference-0.0.1rc1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
173
- tpu_inference-0.0.1rc1.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
174
- tpu_inference-0.0.1rc1.dist-info/RECORD,,
174
+ tpu_inference/worker/tpu_worker.py,sha256=0ZguK2BtIQjQSvyUTcUH9ENBrxt09w3CbgPoDY13Eok,14210
175
+ tpu_inference-0.11.1.dev202511180814.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
176
+ tpu_inference-0.11.1.dev202511180814.dist-info/METADATA,sha256=6dHy_ByQ0ihDNFuqyb-ZXTFczvQ8Ia54zBNTKaUPhSk,5465
177
+ tpu_inference-0.11.1.dev202511180814.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
178
+ tpu_inference-0.11.1.dev202511180814.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
179
+ tpu_inference-0.11.1.dev202511180814.dist-info/RECORD,,
@@ -1,361 +0,0 @@
1
- import re
2
- from typing import Any, List, Optional, Tuple
3
-
4
- import jax
5
- import jax.numpy as jnp
6
- import torch
7
- from flax import nnx
8
- from flax.typing import PRNGKey
9
- from jax.sharding import Mesh
10
- from jax.sharding import PartitionSpec as P
11
- from vllm.config import VllmConfig
12
-
13
- from tpu_inference.layers.jax.attention.attention import AttentionMetadata
14
- from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
15
- from tpu_inference.layers.jax.constants import KVCacheType
16
- from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
17
- from tpu_inference.layers.jax.misc import shard_put
18
- from tpu_inference.layers.jax.transformer_block import TransformerBlock
19
- from tpu_inference.logger import init_logger
20
- from tpu_inference.models.jax.utils.weight_utils import (
21
- get_param, model_weights_generator, print_param_info, reshape_params,
22
- transpose_params)
23
-
24
- logger = init_logger(__name__)
25
-
26
-
27
- class LlamaGuard4ForCausalLM(nnx.Module):
28
-
29
- def __init__(self,
30
- vllm_config: VllmConfig,
31
- rng: PRNGKey,
32
- mesh: Mesh,
33
- force_random_weights: bool = False):
34
- logger.warning(
35
- "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
36
- "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
37
- "Multimodal inputs will fail.\n"
38
- "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
39
- assert mesh is not None
40
-
41
- self.vllm_config = vllm_config
42
- self.vllm_config.model_config.dtype = torch.bfloat16
43
- model_config = vllm_config.model_config
44
- text_config = model_config.hf_config.text_config
45
-
46
- self.mesh = mesh
47
- self.is_verbose = getattr(self.vllm_config.additional_config,
48
- "is_verbose", False)
49
-
50
- self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
51
-
52
- vocab_size = model_config.get_vocab_size()
53
- self.hidden_size = model_config.get_hidden_size()
54
-
55
- self.dtype: jnp.dtype = jnp.bfloat16
56
-
57
- self.num_layers: int = getattr(text_config, "num_layers", 48)
58
- hidden_act: str = getattr(text_config, "hidden_act", "silu")
59
-
60
- rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
61
- self.num_attention_heads = getattr(text_config, "num_attention_heads",
62
- 40)
63
- self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
64
- 8)
65
- self.head_dim = getattr(text_config, "head_dim", 128)
66
-
67
- intermediate_size = getattr(text_config, "intermediate_size", 8192)
68
-
69
- self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
70
- self.rope_scaling = getattr(text_config, "rope_scaling")
71
-
72
- self.rng = nnx.Rngs(rng)
73
-
74
- self.embedder = Embedder(
75
- vocab_size=vocab_size,
76
- hidden_size=self.hidden_size,
77
- dtype=self.dtype,
78
- vd_sharding=(('data', 'model'), None),
79
- rngs=self.rng,
80
- random_init=force_random_weights,
81
- )
82
-
83
- self.layers = []
84
-
85
- for i in range(self.num_layers):
86
- use_attention_rope = True
87
-
88
- custom_module = DenseFFW(dtype=self.dtype,
89
- hidden_act=hidden_act,
90
- hidden_size=self.hidden_size,
91
- intermediate_size=intermediate_size,
92
- random_init=force_random_weights,
93
- rngs=self.rng,
94
- df_sharding=P(None, 'model'),
95
- fd_sharding=P('model', None),
96
- activation_ffw_td=P('data', None))
97
-
98
- attn = Llama4Attention(
99
- hidden_size=self.hidden_size,
100
- dtype=self.dtype,
101
- num_attention_heads=self.num_attention_heads,
102
- num_key_value_heads=self.num_key_value_heads,
103
- head_dim=self.head_dim,
104
- rope_theta=self.rope_theta_text,
105
- rope_scaling={
106
- "scale_factor":
107
- self.rope_scaling["factor"],
108
- "low_freq_factor":
109
- self.rope_scaling["low_freq_factor"],
110
- "high_freq_factor":
111
- self.rope_scaling["high_freq_factor"],
112
- "original_max_position_embeddings":
113
- self.rope_scaling["original_max_position_embeddings"]
114
- },
115
- rngs=self.rng,
116
- rope_input_ordering="interleaved",
117
- # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
118
- kv_cache_dtype=vllm_config.cache_config.cache_dtype,
119
- temperature_tuning=True,
120
- temperature_tuning_scale=0.1,
121
- temperature_tuning_floor_scale=8192,
122
- use_qk_norm=self.use_qk_norm,
123
- attention_chunk_size=None if use_attention_rope else 8192,
124
- mesh=self.mesh,
125
- random_init=force_random_weights,
126
- activation_attention_td=('data', 'model'),
127
- activation_q_td=('data', 'model'),
128
- query_tnh=P('data', 'model', None),
129
- keyvalue_skh=P('data', 'model', None),
130
- activation_attention_out_td=('data', 'model'),
131
- attn_o_tnh=P('data', 'model', None),
132
- dnh_sharding=(None, 'model', None),
133
- dkh_sharding=(None, 'model', None),
134
- nhd_sharding=('model', None, None),
135
- )
136
-
137
- pre_attention_norm = RMSNorm(
138
- dims=self.hidden_size,
139
- random_init=force_random_weights,
140
- epsilon=rms_norm_eps,
141
- rngs=self.rng,
142
- activation_ffw_td=('data', None),
143
- with_scale=True,
144
- dtype=self.dtype,
145
- )
146
-
147
- pre_mlp_norm = RMSNorm(
148
- dims=self.hidden_size,
149
- activation_ffw_td=('data', None),
150
- epsilon=rms_norm_eps,
151
- rngs=self.rng,
152
- with_scale=True,
153
- dtype=self.dtype,
154
- random_init=force_random_weights,
155
- )
156
-
157
- block = TransformerBlock(custom_module=custom_module,
158
- attn=attn,
159
- pre_attention_norm=pre_attention_norm,
160
- pre_mlp_norm=pre_mlp_norm,
161
- use_attention_rope=use_attention_rope)
162
- self.layers.append(block)
163
-
164
- self.final_norm = RMSNorm(
165
- dims=self.hidden_size,
166
- activation_ffw_td=P(),
167
- epsilon=rms_norm_eps,
168
- rngs=self.rng,
169
- with_scale=True,
170
- dtype=self.dtype,
171
- random_init=force_random_weights,
172
- )
173
-
174
- self.lm_head = LMhead(vocab_size=vocab_size,
175
- hidden_size=self.hidden_size,
176
- dtype=self.dtype,
177
- rngs=self.rng,
178
- vd_sharding=(('data', 'model'), None),
179
- dv_sharding=(None, ('data', 'model')),
180
- random_init=force_random_weights)
181
- if self.is_verbose:
182
- self._print_model_architecture()
183
-
184
- def _print_model_architecture(self):
185
-
186
- logger.info("### Embedding ###")
187
- nnx.display(self.embedder)
188
-
189
- logger.info("\n### Layers ###")
190
- for i, layer in enumerate(self.layers):
191
- logger.info(f"\n--- Layer {i} ---")
192
- nnx.display(layer)
193
-
194
- logger.info("\n### LM Head ###")
195
- nnx.display(self.lm_head)
196
-
197
- def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
198
- self.rng = nnx.Rngs(rng)
199
-
200
- weight_loader = LlamaGuard4WeightLoader(
201
- vllm_config=self.vllm_config,
202
- hidden_size=self.hidden_size,
203
- attn_heads=self.num_attention_heads,
204
- num_key_value_heads=self.num_key_value_heads,
205
- attn_head_dim=self.head_dim)
206
- weight_loader.load_weights(self)
207
-
208
- def __call__(
209
- self,
210
- kv_caches: List[jax.Array],
211
- input_ids: jax.Array,
212
- attention_metadata: AttentionMetadata,
213
- inputs_embeds: Optional[jax.Array] = None,
214
- layer_metadata_tuple: Optional[Tuple] = None,
215
- lora_metadata: Optional[Any] = None,
216
- *args,
217
- ) -> Tuple[List[KVCacheType], jax.Array]:
218
- is_prefill = False
219
-
220
- if inputs_embeds is not None:
221
- x_TD = inputs_embeds
222
- elif input_ids is not None:
223
- x_TD = self.embedder.encode(input_ids)
224
- else:
225
- raise ValueError(
226
- "Cannot run forward pass: Both input_ids and inputs_embeds are None."
227
- )
228
-
229
- for (i, block) in enumerate(self.layers):
230
- kv_cache = kv_caches[i]
231
- new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
232
- attention_metadata)
233
- jax.block_until_ready(x_TD)
234
- kv_caches[i] = new_kv_cache
235
-
236
- final_activation_TD = self.final_norm(x_TD)
237
-
238
- return kv_caches, final_activation_TD, []
239
-
240
- def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
241
- logits_TV = jnp.dot(hidden_states,
242
- self.lm_head.input_embedding_table_DV.value)
243
- return logits_TV
244
-
245
- def get_input_embeddings(
246
- self,
247
- input_ids: jax.Array,
248
- multimodal_embeddings: Optional[List[jax.Array]] = None
249
- ) -> jax.Array:
250
- """
251
- Computes the embeddings for text input (used for input to fusion).
252
- """
253
- return self.embedder.encode(input_ids)
254
-
255
-
256
- class LlamaGuard4WeightLoader:
257
-
258
- def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
259
- num_key_value_heads, attn_head_dim):
260
- self.names_and_weights_generator = model_weights_generator(
261
- model_name_or_path=vllm_config.model_config.model,
262
- framework="flax",
263
- filter_regex="language_model",
264
- download_dir=vllm_config.load_config.download_dir)
265
- self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
266
- False)
267
- self._transpose_map = {
268
- "q_proj": (2, 0, 1),
269
- "k_proj": (2, 0, 1),
270
- "v_proj": (2, 0, 1),
271
- "o_proj": (1, 2, 0),
272
- "lm_head": (1, 0),
273
- "feed_forward.down_proj": (1, 0),
274
- "feed_forward.gate_proj": (1, 0),
275
- "feed_forward.up_proj": (1, 0),
276
- "mlp.down_proj": (1, 0),
277
- "mlp.gate_proj": (1, 0),
278
- "mlp.up_proj": (1, 0),
279
- }
280
- self._weight_shape_map = {
281
- "q_proj": (attn_heads, attn_head_dim, hidden_size),
282
- "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
283
- "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
284
- "o_proj": (hidden_size, attn_heads, attn_head_dim),
285
- }
286
-
287
- self._loaded_to_standardized_keys = {
288
- "language_model.model.embed_tokens.weight":
289
- "embedder.input_embedding_table_VD",
290
- "language_model.lm_head.weight":
291
- "lm_head.input_embedding_table_DV",
292
- "language_model.model.norm.weight":
293
- "final_norm.scale",
294
- "language_model.model.layers.*.input_layernorm.weight":
295
- "layers.*.pre_attention_norm.scale",
296
- "language_model.model.layers.*.post_attention_layernorm.weight":
297
- "layers.*.pre_mlp_norm.scale",
298
- "language_model.model.layers.*.self_attn.q_proj.weight":
299
- "layers.*.attn.kernel_q_proj_DNH",
300
- "language_model.model.layers.*.self_attn.k_proj.weight":
301
- "layers.*.attn.kernel_k_proj_DKH",
302
- "language_model.model.layers.*.self_attn.v_proj.weight":
303
- "layers.*.attn.kernel_v_proj_DKH",
304
- "language_model.model.layers.*.self_attn.o_proj.weight":
305
- "layers.*.attn.kernel_o_proj_NHD",
306
- "language_model.model.layers.*.feed_forward.gate_proj.weight":
307
- "layers.*.custom_module.kernel_gating_DF",
308
- "language_model.model.layers.*.feed_forward.up_proj.weight":
309
- "layers.*.custom_module.kernel_up_proj_DF",
310
- "language_model.model.layers.*.feed_forward.down_proj.weight":
311
- "layers.*.custom_module.kernel_down_proj_FD",
312
- }
313
-
314
- def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
315
- if "layer" in loaded_key:
316
- layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
317
- layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
318
- mapped_key = self._loaded_to_standardized_keys.get(
319
- layer_key, loaded_key)
320
- mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
321
- mapped_key)
322
- else:
323
- mapped_key = self._loaded_to_standardized_keys.get(
324
- loaded_key, loaded_key)
325
- return mapped_key
326
-
327
- def load_weights(self, model_for_loading: nnx.Module):
328
- model_params = nnx.state(model_for_loading)
329
- with jax.default_device(jax.devices("cpu")[0]):
330
- for loaded_name, loaded_weight in self.names_and_weights_generator:
331
- if loaded_name.endswith(".bias"):
332
- continue
333
- if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
334
- continue
335
-
336
- mapped_name = self.map_loaded_to_standardized_name(loaded_name)
337
- model_weight = get_param(model_params, mapped_name)
338
-
339
- if not loaded_name.endswith(".bias"):
340
- # For other layers, continue to use the transpose_params helper.
341
- loaded_weight = reshape_params(loaded_name, loaded_weight,
342
- self._weight_shape_map)
343
- loaded_weight = transpose_params(loaded_name,
344
- loaded_weight,
345
- self._transpose_map)
346
- if model_weight.value.shape != loaded_weight.shape:
347
- raise ValueError(
348
- f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
349
- f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
350
- )
351
- logger.debug(
352
- f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
353
- )
354
-
355
- model_weight.value = shard_put(loaded_weight,
356
- model_weight.sharding,
357
- mesh=model_for_loading.mesh)
358
- if self.is_verbose:
359
- print_param_info(model_weight, loaded_name)
360
-
361
- nnx.update(model_for_loading, model_params)