tpu-inference 0.11.1.dev202511180814__tar.gz → 0.11.1.dev202511270815__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. (The original page links to more details here; the link target is not preserved in this text export.)

Files changed (189)
  1. {tpu_inference-0.11.1.dev202511180814/tpu_inference.egg-info → tpu_inference-0.11.1.dev202511270815}/PKG-INFO +3 -2
  2. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/requirements.txt +2 -1
  3. tpu_inference-0.11.1.dev202511270815/tests/kernels/fused_moe_v1_test.py +374 -0
  4. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/lora/test_layers.py +0 -6
  5. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/lora/utils.py +0 -8
  6. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/__init__.py +22 -3
  7. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/core/disagg_utils.py +6 -8
  8. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/distributed/tpu_connector.py +2 -3
  9. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/distributed/utils.py +3 -2
  10. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/envs.py +1 -1
  11. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/executors/ray_distributed_executor.py +4 -1
  12. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
  15. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/common/attention_interface.py +7 -1
  16. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/common/sharding.py +2 -1
  17. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
  20. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/unquantized.py +14 -8
  21. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/sharding.py +2 -2
  22. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/common/model_loader.py +41 -10
  24. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/llama3.py +2 -1
  25. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference-0.11.1.dev202511270815/tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/qwen2.py +2 -1
  28. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/qwen3.py +2 -1
  30. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/weight_utils.py +198 -143
  31. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/vllm/vllm_model_wrapper.py +13 -6
  32. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/platforms/tpu_platform.py +15 -2
  33. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/compilation_manager.py +55 -32
  34. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/kv_cache_manager.py +9 -3
  35. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/structured_decoding_manager.py +2 -3
  36. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/tpu_runner.py +203 -102
  37. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/spec_decode/jax/eagle3.py +19 -2
  38. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/tpu_info.py +4 -3
  39. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/utils.py +5 -4
  40. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/worker/tpu_worker.py +160 -23
  41. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815/tpu_inference.egg-info}/PKG-INFO +3 -2
  42. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference.egg-info/SOURCES.txt +1 -6
  43. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference.egg-info/requires.txt +2 -1
  44. tpu_inference-0.11.1.dev202511180814/tests/kernels/fused_moe_v1_test.py +0 -105
  45. tpu_inference-0.11.1.dev202511180814/tpu_inference/mock/vllm_config_utils.py +0 -28
  46. tpu_inference-0.11.1.dev202511180814/tpu_inference/mock/vllm_envs.py +0 -1219
  47. tpu_inference-0.11.1.dev202511180814/tpu_inference/mock/vllm_logger.py +0 -212
  48. tpu_inference-0.11.1.dev202511180814/tpu_inference/mock/vllm_logging_utils.py +0 -15
  49. tpu_inference-0.11.1.dev202511180814/tpu_inference/models/jax/phi3.py +0 -376
  50. tpu_inference-0.11.1.dev202511180814/tpu_inference/worker/__init__.py +0 -0
  51. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/LICENSE +0 -0
  52. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/MANIFEST.in +0 -0
  53. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/README.md +0 -0
  54. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/pyproject.toml +0 -0
  55. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/setup.cfg +0 -0
  56. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/setup.py +0 -0
  57. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/__init__.py +0 -0
  58. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/core/__init__.py +0 -0
  59. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/core/test_core_tpu.py +0 -0
  60. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/core/test_disagg_executor.py +0 -0
  61. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/core/test_disagg_utils.py +0 -0
  62. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/core/test_dp_scheduler.py +0 -0
  63. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/core/test_init.py +0 -0
  64. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/__init__.py +0 -0
  65. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/mla_v1_test.py +0 -0
  66. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/quantized_matmul_kernel_test.py +0 -0
  67. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
  68. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
  69. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +0 -0
  70. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +0 -0
  71. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/lora/__init__.py +0 -0
  72. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/lora/conftest.py +0 -0
  73. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/lora/test_bgmv.py +0 -0
  74. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/lora/test_lora.py +0 -0
  75. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/test_base.py +0 -0
  76. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/test_envs.py +0 -0
  77. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/test_quantization.py +0 -0
  78. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/test_tpu_info.py +0 -0
  79. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tests/test_utils.py +0 -0
  80. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/core/__init__.py +0 -0
  81. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/core/core_tpu.py +0 -0
  82. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/core/disagg_executor.py +0 -0
  83. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/core/sched/__init__.py +0 -0
  84. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/core/sched/dp_scheduler.py +0 -0
  85. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/distributed/__init__.py +0 -0
  86. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/distributed/jax_parallel_state.py +0 -0
  87. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/env_override.py +0 -0
  88. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/executors/__init__.py +0 -0
  89. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/experimental/__init__.py +0 -0
  90. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
  91. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/__init__.py +0 -0
  92. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/collectives/__init__.py +0 -0
  93. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/collectives/all_gather_matmul.py +0 -0
  94. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +0 -0
  95. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/collectives/util.py +0 -0
  96. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
  97. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
  98. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
  99. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  100. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/mla/__init__.py +0 -0
  101. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/mla/v1/__init__.py +0 -0
  102. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/mla/v1/kernel.py +0 -0
  103. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  104. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/quantized_matmul/kernel.py +0 -0
  105. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
  106. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
  107. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  108. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  109. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +0 -0
  110. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +0 -0
  111. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
  112. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  113. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -0
  114. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -0
  115. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +0 -0
  116. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/__init__.py +0 -0
  117. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/common/__init__.py +0 -0
  118. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/common/attention_metadata.py +0 -0
  119. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/common/binary_search.py +0 -0
  120. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/common/quant_methods.py +0 -0
  121. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/__init__.py +0 -0
  122. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/attention/__init__.py +0 -0
  123. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/attention/attention.py +0 -0
  124. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +0 -0
  125. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +0 -0
  126. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
  127. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/base.py +0 -0
  128. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/constants.py +0 -0
  129. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/layers.py +0 -0
  130. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/misc.py +0 -0
  131. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/moe/__init__.py +0 -0
  132. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
  133. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
  134. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/moe/moe.py +0 -0
  135. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/rope.py +0 -0
  136. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/rope_interface.py +0 -0
  137. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/sample/__init__.py +0 -0
  138. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
  139. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/sample/sampling.py +0 -0
  140. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
  141. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/jax/transformer_block.py +0 -0
  142. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/__init__.py +0 -0
  143. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/attention.py +0 -0
  144. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/linear_common.py +0 -0
  145. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/__init__.py +0 -0
  146. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/awq.py +0 -0
  147. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  148. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +0 -0
  149. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -0
  150. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  151. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +0 -0
  152. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +0 -0
  153. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/logger.py +0 -0
  154. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/lora/__init__.py +0 -0
  155. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/lora/torch_lora_ops.py +0 -0
  156. {tpu_inference-0.11.1.dev202511180814/tpu_inference/mock → tpu_inference-0.11.1.dev202511270815/tpu_inference/models}/__init__.py +0 -0
  157. {tpu_inference-0.11.1.dev202511180814/tpu_inference/models → tpu_inference-0.11.1.dev202511270815/tpu_inference/models/common}/__init__.py +0 -0
  158. {tpu_inference-0.11.1.dev202511180814/tpu_inference/models/common → tpu_inference-0.11.1.dev202511270815/tpu_inference/models/jax}/__init__.py +0 -0
  159. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/deepseek_v3.py +0 -0
  160. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/gpt_oss.py +0 -0
  161. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
  162. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/llama4.py +0 -0
  163. {tpu_inference-0.11.1.dev202511180814/tpu_inference/models/jax → tpu_inference-0.11.1.dev202511270815/tpu_inference/models/jax/utils}/__init__.py +0 -0
  164. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/file_utils.py +0 -0
  165. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
  166. {tpu_inference-0.11.1.dev202511180814/tpu_inference/models/jax/utils → tpu_inference-0.11.1.dev202511270815/tpu_inference/models/jax/utils/quantization}/__init__.py +0 -0
  167. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -0
  168. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -0
  169. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -0
  170. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -0
  171. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -0
  172. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/jax/utils/quantization/quantization_utils.py +0 -0
  173. {tpu_inference-0.11.1.dev202511180814/tpu_inference/models/jax/utils/quantization → tpu_inference-0.11.1.dev202511270815/tpu_inference/models/vllm}/__init__.py +0 -0
  174. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
  175. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/platforms/__init__.py +0 -0
  176. {tpu_inference-0.11.1.dev202511180814/tpu_inference/models/vllm → tpu_inference-0.11.1.dev202511270815/tpu_inference/runner}/__init__.py +0 -0
  177. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/block_table.py +0 -0
  178. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/input_batch.py +0 -0
  179. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/kv_cache.py +0 -0
  180. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/lora_utils.py +0 -0
  181. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/multimodal_manager.py +0 -0
  182. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/persistent_batch_manager.py +0 -0
  183. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
  184. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference/runner/utils.py +0 -0
  185. {tpu_inference-0.11.1.dev202511180814/tpu_inference/runner → tpu_inference-0.11.1.dev202511270815/tpu_inference/spec_decode}/__init__.py +0 -0
  186. {tpu_inference-0.11.1.dev202511180814/tpu_inference/spec_decode → tpu_inference-0.11.1.dev202511270815/tpu_inference/spec_decode/jax}/__init__.py +0 -0
  187. {tpu_inference-0.11.1.dev202511180814/tpu_inference/spec_decode/jax → tpu_inference-0.11.1.dev202511270815/tpu_inference/worker}/__init__.py +0 -0
  188. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference.egg-info/dependency_links.txt +0 -0
  189. {tpu_inference-0.11.1.dev202511180814 → tpu_inference-0.11.1.dev202511270815}/tpu_inference.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tpu_inference
3
- Version: 0.11.1.dev202511180814
3
+ Version: 0.11.1.dev202511270815
4
4
  Author: tpu_inference Contributors
5
5
  Classifier: Development Status :: 3 - Alpha
6
6
  Classifier: Intended Audience :: Developers
@@ -27,10 +27,11 @@ Requires-Dist: jaxtyping
27
27
  Requires-Dist: flax==0.11.1
28
28
  Requires-Dist: torchax==0.0.7
29
29
  Requires-Dist: qwix==0.1.1
30
- Requires-Dist: torchvision==0.23.0
30
+ Requires-Dist: torchvision==0.24.0
31
31
  Requires-Dist: pathwaysutils
32
32
  Requires-Dist: parameterized
33
33
  Requires-Dist: numba==0.62.1
34
+ Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
34
35
  Dynamic: author
35
36
  Dynamic: classifier
36
37
  Dynamic: description
@@ -11,7 +11,8 @@ jaxtyping
11
11
  flax==0.11.1
12
12
  torchax==0.0.7
13
13
  qwix==0.1.1
14
- torchvision==0.23.0
14
+ torchvision==0.24.0
15
15
  pathwaysutils
16
16
  parameterized
17
17
  numba==0.62.1
18
+ runai-model-streamer[s3,gcs]==0.15.0
@@ -0,0 +1,374 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from absl.testing import absltest, parameterized
5
+ from jax._src import test_util as jtu
6
+ from jax.sharding import Mesh
7
+
8
+ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
9
+
10
+ jax.config.parse_flags_with_absl()
11
+
12
+
13
+ def cdiv(a, b):
14
+ assert b != 0
15
+ return (a + b - 1) // b
16
+
17
+
18
+ def align_to(x, a):
19
+ return cdiv(x, a) * a
20
+
21
+
22
+ def gen_moe_inputs(
23
+ dtype,
24
+ top_k,
25
+ num_experts,
26
+ hidden_size,
27
+ intermediate_size,
28
+ num_tokens,
29
+ *,
30
+ seed=1234,
31
+ has_bias=False,
32
+ ):
33
+ key = jax.random.key(seed)
34
+ k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
35
+
36
+ a = jax.random.normal(k0, (num_tokens, hidden_size),
37
+ dtype=jnp.float32).astype(dtype) / 10
38
+
39
+ w1 = (jax.random.normal(
40
+ k1,
41
+ (num_experts, 2, hidden_size, intermediate_size),
42
+ dtype=jnp.float32,
43
+ ) / 10).astype(dtype)
44
+ w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
45
+ dtype=jnp.float32) / 10).astype(dtype)
46
+
47
+ if has_bias:
48
+ b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
49
+ dtype=jnp.float32) / 10).astype(dtype)
50
+ b2 = (jax.random.normal(k4, (num_experts, hidden_size),
51
+ dtype=jnp.float32) / 10).astype(dtype)
52
+ else:
53
+ b1 = b2 = None
54
+
55
+ gating_output = (
56
+ jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
57
+ jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
58
+ num_tokens, num_experts) / 100)
59
+
60
+ # To generate unique top-k!
61
+ top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
62
+ minval=0,
63
+ maxval=num_experts - 1,
64
+ dtype=jnp.int32)
65
+
66
+ one_hot = (jnp.sum(
67
+ jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
68
+ axis=1,
69
+ ) * 30)
70
+
71
+ gating_output = (gating_output + one_hot).astype(dtype)
72
+
73
+ return a, w1, w2, b1, b2, gating_output
74
+
75
+
76
+ def sub_channel_quantize(x, quant_dtype, wsz=256):
77
+ """Quantizes x with sub-channel quantization on the 2nd minor."""
78
+ if jnp.issubdtype(quant_dtype, jnp.floating):
79
+ dtype_info = jnp.finfo(quant_dtype)
80
+ else:
81
+ dtype_info = jnp.iinfo(quant_dtype)
82
+ dtype_max = float(dtype_info.max)
83
+ w_lst, scale_lst = [], []
84
+ assert len(x.shape) >= 2
85
+ assert x.shape[-2] % wsz == 0
86
+ for i in range(0, x.shape[-2], wsz):
87
+ y = x[..., i:i + wsz, :]
88
+ abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
89
+ scale = (abs_max / dtype_max).astype(jnp.float32)
90
+ w = (y / scale).astype(quant_dtype)
91
+ w_lst.append(w)
92
+ scale_lst.append(scale)
93
+ return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
94
+
95
+
96
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
97
+ class MoEKernelTest(jtu.JaxTestCase):
98
+
99
+ def setUp(self):
100
+ super().setUp()
101
+ self.mesh_devices = sorted(
102
+ jax.devices(),
103
+ key=lambda x: (
104
+ x.coords[0],
105
+ (-1 if x.coords[0] % 2 else 1) * x.coords[1],
106
+ ),
107
+ )
108
+ self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
109
+ axis_names=("data", "model"))
110
+
111
+ def _test_moe(
112
+ self,
113
+ dtype,
114
+ top_k,
115
+ num_experts,
116
+ hidden_size,
117
+ intermediate_size,
118
+ num_tokens,
119
+ seed,
120
+ renormalize_topk_logits,
121
+ bt,
122
+ bf,
123
+ bd1,
124
+ bd2,
125
+ btc,
126
+ bfc,
127
+ bd1c,
128
+ bd2c,
129
+ act_fn="silu",
130
+ w_dtype=None,
131
+ subc_quant_wsz=None,
132
+ has_bias=False,
133
+ atol=2e-1,
134
+ rtol=2e-1,
135
+ ):
136
+ a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
137
+ dtype,
138
+ top_k,
139
+ num_experts,
140
+ hidden_size,
141
+ intermediate_size,
142
+ num_tokens,
143
+ seed=seed,
144
+ has_bias=has_bias,
145
+ )
146
+ w1_scale = None
147
+ w2_scale = None
148
+ if w_dtype is not None:
149
+ if subc_quant_wsz is None:
150
+ subc_quant_wsz = 256
151
+ w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
152
+ w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
153
+
154
+ actual = fused_ep_moe(
155
+ mesh=self.mesh,
156
+ tokens=a,
157
+ w1=w1,
158
+ w2=w2,
159
+ gating_output=gating_output,
160
+ top_k=top_k,
161
+ renormalize_topk_logits=renormalize_topk_logits,
162
+ act_fn=act_fn,
163
+ subc_quant_wsz=subc_quant_wsz,
164
+ w1_scale=w1_scale,
165
+ w2_scale=w2_scale,
166
+ b1=b1,
167
+ b2=b2,
168
+ bt=bt,
169
+ bf=bf,
170
+ bd1=bd1,
171
+ bd2=bd2,
172
+ btc=btc,
173
+ bfc=bfc,
174
+ bd1c=bd1c,
175
+ bd2c=bd2c,
176
+ )
177
+ expected = ref_moe(
178
+ a,
179
+ w1,
180
+ w2,
181
+ gating_output,
182
+ top_k,
183
+ b1=b1,
184
+ b2=b2,
185
+ renormalize_topk_logits=renormalize_topk_logits,
186
+ activation=act_fn,
187
+ subc_quant_wsz=subc_quant_wsz,
188
+ w1_scale=w1_scale,
189
+ w2_scale=w2_scale,
190
+ )
191
+ self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
192
+
193
+ @parameterized.product(renormalize_topk_logits=[True, False], )
194
+ def test_basic(self, renormalize_topk_logits):
195
+ dtype = jnp.bfloat16
196
+ top_k = 8
197
+ num_experts = 128
198
+ hidden_size = 1024
199
+ intermediate_size = 1024
200
+ num_tokens = 8 * 32
201
+ self._test_moe(
202
+ dtype=dtype,
203
+ top_k=top_k,
204
+ num_experts=num_experts,
205
+ hidden_size=hidden_size,
206
+ intermediate_size=intermediate_size,
207
+ num_tokens=num_tokens,
208
+ seed=1234,
209
+ renormalize_topk_logits=renormalize_topk_logits,
210
+ bt=32,
211
+ bf=1024,
212
+ bd1=1024,
213
+ bd2=1024,
214
+ btc=32,
215
+ bfc=256,
216
+ bd1c=256,
217
+ bd2c=256,
218
+ )
219
+
220
+ @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
221
+ def test_activation(self, act_fn):
222
+ dtype = jnp.bfloat16
223
+ top_k = 8
224
+ num_experts = 128
225
+ hidden_size = 1024
226
+ intermediate_size = 1024
227
+ num_tokens = 8 * 32
228
+ self._test_moe(
229
+ dtype=dtype,
230
+ top_k=top_k,
231
+ num_experts=num_experts,
232
+ hidden_size=hidden_size,
233
+ intermediate_size=intermediate_size,
234
+ num_tokens=num_tokens,
235
+ seed=1234,
236
+ renormalize_topk_logits=True,
237
+ act_fn=act_fn,
238
+ bt=32,
239
+ bf=512,
240
+ bd1=512,
241
+ bd2=512,
242
+ btc=32,
243
+ bfc=256,
244
+ bd1c=256,
245
+ bd2c=256,
246
+ )
247
+
248
+ def test_benchmark_qwen_235(self):
249
+ num_experts = 128
250
+ top_k = 8
251
+ hidden_size = 4096
252
+ intermediate_size = 1536
253
+ dtype = jnp.bfloat16
254
+ num_tokens = 8 * 64
255
+ seed = 54321
256
+ renormalize_topk_logits = True
257
+ self._test_moe(
258
+ dtype=dtype,
259
+ top_k=top_k,
260
+ num_experts=num_experts,
261
+ hidden_size=hidden_size,
262
+ intermediate_size=intermediate_size,
263
+ num_tokens=num_tokens,
264
+ seed=seed,
265
+ renormalize_topk_logits=renormalize_topk_logits,
266
+ bt=64,
267
+ bf=768,
268
+ bd1=2048,
269
+ bd2=2048,
270
+ btc=64,
271
+ bfc=768,
272
+ bd1c=2048,
273
+ bd2c=2048,
274
+ act_fn="silu",
275
+ atol=5e-2,
276
+ rtol=5e-2,
277
+ )
278
+
279
+ def test_benchmark_qwen_30b_a3b(self):
280
+ num_experts = 128
281
+ top_k = 8
282
+ hidden_size = 2048
283
+ intermediate_size = 768
284
+ dtype = jnp.bfloat16
285
+ num_tokens = 512
286
+ seed = 54321
287
+ renormalize_topk_logits = True
288
+ self._test_moe(
289
+ dtype=dtype,
290
+ top_k=top_k,
291
+ num_experts=num_experts,
292
+ hidden_size=hidden_size,
293
+ intermediate_size=intermediate_size,
294
+ num_tokens=num_tokens,
295
+ seed=seed,
296
+ renormalize_topk_logits=renormalize_topk_logits,
297
+ bt=16,
298
+ bf=384,
299
+ bd1=512,
300
+ bd2=512,
301
+ btc=16,
302
+ bfc=384,
303
+ bd1c=256,
304
+ bd2c=256,
305
+ act_fn="silu",
306
+ atol=5e-2,
307
+ rtol=5e-2,
308
+ )
309
+
310
+ @parameterized.product(
311
+ w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
312
+ def test_sub_channel_quantization(self, w_dtype):
313
+ if w_dtype in (
314
+ jnp.float8_e5m2,
315
+ jnp.float4_e2m1fn,
316
+ ) and not jtu.is_device_tpu_at_least(version=7):
317
+ self.skipTest("Expect TPUv7+")
318
+ dtype = jnp.bfloat16
319
+ top_k = 8
320
+ num_experts = 128
321
+ hidden_size = 1024
322
+ intermediate_size = 1024
323
+ num_tokens = 8 * 32
324
+ self._test_moe(
325
+ dtype=dtype,
326
+ top_k=top_k,
327
+ num_experts=num_experts,
328
+ hidden_size=hidden_size,
329
+ intermediate_size=intermediate_size,
330
+ num_tokens=num_tokens,
331
+ seed=1234,
332
+ renormalize_topk_logits=False,
333
+ w_dtype=w_dtype,
334
+ subc_quant_wsz=256,
335
+ bt=32,
336
+ bf=1024,
337
+ bd1=1024,
338
+ bd2=1024,
339
+ btc=32,
340
+ bfc=256,
341
+ bd1c=256,
342
+ bd2c=256,
343
+ )
344
+
345
+ def test_bias(self):
346
+ dtype = jnp.bfloat16
347
+ top_k = 8
348
+ num_experts = 128
349
+ hidden_size = 1024
350
+ intermediate_size = 1024
351
+ num_tokens = 8 * 32
352
+ self._test_moe(
353
+ dtype=dtype,
354
+ top_k=top_k,
355
+ num_experts=num_experts,
356
+ hidden_size=hidden_size,
357
+ intermediate_size=intermediate_size,
358
+ num_tokens=num_tokens,
359
+ seed=1234,
360
+ renormalize_topk_logits=False,
361
+ has_bias=True,
362
+ bt=32,
363
+ bf=512,
364
+ bd1=512,
365
+ bd2=512,
366
+ btc=32,
367
+ bfc=256,
368
+ bd1c=256,
369
+ bd2c=256,
370
+ )
371
+
372
+
373
+ if __name__ == "__main__":
374
+ absltest.main(testLoader=jtu.JaxTestLoader())
@@ -91,7 +91,6 @@ def populate_loras(
91
91
  index_to_id: list[Optional[int]],
92
92
  lora_layer: BaseLayerWithLoRA,
93
93
  baselayer_weights: torch.Tensor,
94
- generate_embeddings_tensor: int = 0,
95
94
  repeats: int = 1,
96
95
  ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
97
96
  """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
103
102
  lora_layer: the LoRAlayer to populate.
104
103
  baselayer_weights: the PyTorch tensor containing the layer's
105
104
  weights.
106
- generate_embeddings_tensor: whether to generate an
107
- embeddings tensor for each LoRA.
108
105
  repeats: must only be set for column parallel packed
109
106
  layers. Indicates the number of loras to compose
110
107
  together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
131
128
  baselayer_weights.device).init_random_lora(
132
129
  module_name=f"fake_{i}",
133
130
  weight=baselayer_weights,
134
- generate_embeddings_tensor=generate_embeddings_tensor,
135
131
  )
136
132
  sublora.lora_b = sublora.lora_b[(sublora_len *
137
133
  i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
147
143
  slot_idx,
148
144
  lora_a=lora.lora_a,
149
145
  lora_b=lora.lora_b,
150
- embeddings_tensor=lora.embeddings_tensor,
151
146
  )
152
147
 
153
148
  lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
546
541
  index_to_id,
547
542
  lora_config.max_loras,
548
543
  vocab_size=512,
549
- extra_vocab_size=lora_config.lora_extra_vocab_size,
550
544
  )
551
545
  assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
552
546
  ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
@@ -24,7 +24,6 @@ class DummyLoRAManager:
24
24
  module_name: str,
25
25
  weight: torch.Tensor,
26
26
  rank: int = 8,
27
- generate_embeddings_tensor: int = 0,
28
27
  ):
29
28
  lora = LoRALayerWeights(
30
29
  module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
37
36
  dtype=weight.dtype,
38
37
  device=self._device),
39
38
  )
40
- if generate_embeddings_tensor:
41
- lora.embeddings_tensor = torch.rand(
42
- 5,
43
- generate_embeddings_tensor,
44
- dtype=weight.dtype,
45
- device=self._device,
46
- )
47
39
  self.set_module_lora(module_name, lora)
48
40
 
49
41
  return lora
@@ -1,21 +1,40 @@
1
- import os
2
-
3
1
  # The environment variables override should be imported before any other
4
2
  # modules to ensure that the environment variables are set before any
5
3
  # other modules are imported.
6
4
  import tpu_inference.env_override # noqa: F401
5
+ from tpu_inference import envs
7
6
  from tpu_inference import tpu_info as ti
8
7
  from tpu_inference.logger import init_logger
9
8
 
10
9
  logger = init_logger(__name__)
11
10
 
12
- if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
11
+ if "proxy" in envs.JAX_PLATFORMS:
13
12
  logger.info("Running vLLM on TPU via Pathways proxy.")
14
13
  # Must run pathwaysutils.initialize() before any JAX operations
15
14
  try:
15
+ import traceback
16
+
16
17
  import pathwaysutils
18
+ import vllm
19
+ from vllm.platforms import (resolve_current_platform_cls_qualname,
20
+ resolve_obj_by_qualname)
17
21
  pathwaysutils.initialize()
18
22
  logger.info("Module pathwaysutils is imported.")
23
+
24
+ # Pathways requires eager resolution of vllm.current_platform instead of
25
+ # lazy resolution in the normal code path. Since this part involves
26
+ # global topology discovery across multiple hosts, the platform
27
+ # resolution must happen before other components are loaded.
28
+ logger.info("Eagerly resolving vLLM current_platform for Pathways.")
29
+ platform_cls_qualname = resolve_current_platform_cls_qualname()
30
+ resolved_platform_instance = resolve_obj_by_qualname(
31
+ platform_cls_qualname)()
32
+ vllm.platforms._current_platform = resolved_platform_instance
33
+ vllm.platforms._init_trace = "".join(traceback.format_stack())
34
+ logger.info(
35
+ f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
36
+ )
37
+
19
38
  except Exception as e:
20
39
  logger.error(
21
40
  f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
@@ -1,17 +1,15 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
 
3
- import os
4
3
  from typing import Tuple
5
4
 
6
- PREFILL_SLICES = 'PREFILL_SLICES'
7
- DECODE_SLICES = 'DECODE_SLICES'
5
+ from tpu_inference import envs
8
6
 
9
7
 
10
8
  def is_disagg_enabled() -> bool:
11
9
  # We trigger our code path as long as prefill slices are set. This
12
10
  # allows us to test interleave mode effectively with the code path
13
11
  # for comparison purposes.
14
- return PREFILL_SLICES in os.environ
12
+ return bool(envs.PREFILL_SLICES)
15
13
 
16
14
 
17
15
  def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
40
38
 
41
39
 
42
40
  def get_prefill_slices() -> Tuple[int, ...]:
43
- if PREFILL_SLICES not in os.environ:
41
+ if not envs.PREFILL_SLICES:
44
42
  return ()
45
- return _parse_slices(os.environ[PREFILL_SLICES])
43
+ return _parse_slices(envs.PREFILL_SLICES)
46
44
 
47
45
 
48
46
  def get_decode_slices() -> Tuple[int, ...]:
49
- if DECODE_SLICES not in os.environ:
47
+ if not envs.DECODE_SLICES:
50
48
  return ()
51
- return _parse_slices(os.environ[DECODE_SLICES])
49
+ return _parse_slices(envs.DECODE_SLICES)
@@ -60,7 +60,6 @@ D workflow:
60
60
 
61
61
  import copy
62
62
  import functools
63
- import os
64
63
  import threading
65
64
  import time
66
65
  from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
86
85
  from vllm.v1.core.kv_cache_manager import KVCacheBlocks
87
86
  from vllm.v1.request import Request
88
87
 
88
+ from tpu_inference import envs
89
89
  from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
90
90
  get_kv_ports,
91
91
  get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:
441
441
 
442
442
  self.runner: TPUModelRunner = None
443
443
  self.mesh: Mesh = None
444
- self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
445
- "").lower() == "ray"
444
+ self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
446
445
  # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
447
446
  # The worker rank is assigned with vLLM's sorting logic, which does not work
448
447
  # for TPU host topology.
@@ -2,6 +2,7 @@ import os
2
2
 
3
3
  from vllm.utils.network_utils import get_ip
4
4
 
5
+ from tpu_inference import envs
5
6
  from tpu_inference.logger import init_logger
6
7
 
7
8
  logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
17
18
 
18
19
 
19
20
  def get_kv_ips() -> str:
20
- if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
21
+ if envs.TPU_MULTIHOST_BACKEND == "ray":
21
22
  num_nodes = len(_NODES_KV_IP_PORT)
22
23
  ips = []
23
24
  for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
28
29
 
29
30
 
30
31
  def get_kv_ports() -> str:
31
- if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
32
+ if envs.TPU_MULTIHOST_BACKEND == "ray":
32
33
  num_nodes = len(_NODES_KV_IP_PORT)
33
34
  ports = []
34
35
  for node_id in range(num_nodes):
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
26
26
  environment_variables: dict[str, Callable[[], Any]] = {
27
27
  # JAX platform selection (e.g., "tpu", "cpu", "proxy")
28
28
  "JAX_PLATFORMS":
29
- lambda: os.getenv("JAX_PLATFORMS", ""),
29
+ lambda: os.getenv("JAX_PLATFORMS", "").lower(),
30
30
  # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
31
31
  "TPU_ACCELERATOR_TYPE":
32
32
  lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
108
108
  ip_port = self.collective_rpc("get_node_kv_ip_port")
109
109
  for item in ip_port:
110
110
  set_node_kv_ip_port(item)
111
+ self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
112
+ self.vllm_config.ec_transfer_config is None
113
+ or not self.vllm_config.ec_transfer_config.is_ec_producer)
111
114
 
112
115
  def _initialize_ray_cluster(self) -> None:
113
116
  """Initialize the distributed cluster with Ray.
@@ -352,7 +355,7 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
352
355
  self.collective_rpc("init_worker", args=(all_kwargs, ))
353
356
  self.collective_rpc("init_device")
354
357
  if self.parallel_config.pipeline_parallel_size > 1:
355
- self._run_workers("initialize_pp_transfer_connect")
358
+ self.collective_rpc("initialize_pp_transfer_connect")
356
359
  self.collective_rpc("load_model")
357
360
 
358
361
  if self.use_ray_spmd_worker: