tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/runner/utils.py
@@ -0,0 +1,426 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ Implements a few utility functions for the various runners.
+ """
+ import bisect
+ import datetime
+ import functools
+ import json
+ import os
+ import time
+ from enum import Enum
+ from typing import Any
+
+ import jax
+ from jax._src.interpreters import pxla
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+ from tpu_inference import envs
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.input_batch import InputBatch
+
+ MIN_NUM_SEQS = 8
+
+ # These are used for determining the inference phase for a given batch in
+ # determine_phase_from_batch_composition_stats.
+ # We will say that any batch that has at least 90% of its tokens scheduled
+ # for prefilling is in the PREFILL_HEAVY phase.
+ PREFILL_HEAVY_RATIO_THRESHOLD = 0.9
+ # We will say that any batch that has at most 20% of its tokens scheduled
+ # for prefilling is in the DECODE_HEAVY phase.
+ DECODE_HEAVY_RATIO_THRESHOLD = 0.2
+ # We will say that any batch that has between 40% and 60% of its tokens
+ # scheduled for prefilling is in the BALANCED phase.
+ BALANCED_RATIO_THRESHOLD = (0.4, 0.6)
+ PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR = 15
+
+ logger = init_logger(__name__)
+
+
+ class InferencePhase(Enum):
+     PREFILL_HEAVY = 0
+     DECODE_HEAVY = 1
+     BALANCED = 2
+     AMBIGUOUS = 3
+
+
+ def get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
+     # Round up to the next power of two, with a floor of MIN_NUM_SEQS,
+     # clamped to upper_limit.
+     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
+     return min(res, upper_limit)
+
+
+ def get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
+     # min_req_size must be a power of 2.
+     assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
+     paddings: list = []
+     num = max(MIN_NUM_SEQS, min_req_size)
+     while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
+         paddings.append(num)
+         num = get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
+     logger.info(f"Prepared request paddings: {paddings}")
+     return paddings
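
To make the rounding behavior concrete, a small illustrative sketch (not part of the package) of what these two helpers return:

    # Round up to the next power of two, floored at MIN_NUM_SEQS = 8 and
    # clamped to the upper limit.
    assert get_padded_num_reqs_with_upper_limit(3, 64) == 8
    assert get_padded_num_reqs_with_upper_limit(9, 64) == 16
    assert get_padded_num_reqs_with_upper_limit(100, 64) == 64

    # Enumerate those sizes up to max_req_size (the cap itself is kept).
    assert get_req_paddings(8, 100) == [8, 16, 32, 64, 100]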
+
+
+ def get_token_paddings(min_token_size: int, max_token_size: int,
+                        padding_gap: int) -> list[int]:
+     """Generate a list of padding sizes, starting from min_token_size and
+     ending with a number that can cover max_token_size.
+
+     If padding_gap == 0:
+         grow 2x each time (exponential).
+     else:
+         first double the size until padding_gap is reached,
+         then increase the padding size linearly by padding_gap.
+     """
+     # min_token_size must be a power of 2.
+     assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
+     paddings = []
+     num = min_token_size
+
+     if padding_gap == 0:
+         while True:
+             paddings.append(num)
+             if num >= max_token_size:
+                 break
+             num *= 2
+     else:
+         while num <= padding_gap:
+             paddings.append(num)
+             num *= 2
+         num //= 2
+         while num < max_token_size:
+             num += padding_gap
+             paddings.append(num)
+     logger.info(f"Prepared token paddings: {paddings}")
+     return paddings
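
A worked example (illustrative only, not part of the package) of the two growth modes:

    # padding_gap == 0: purely exponential growth until max_token_size is covered.
    assert get_token_paddings(16, 300, 0) == [16, 32, 64, 128, 256, 512]

    # padding_gap == 64: double up to the gap, then step linearly by the gap.
    assert get_token_paddings(16, 300, 64) == [16, 32, 64, 128, 192, 256, 320]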
+
+
+ def get_padded_token_len(paddings: list[int], x: int) -> int:
+     """Return the first element in the paddings list greater than or equal
+     to x.
+     """
+     index = bisect.bisect_left(paddings, x)
+     assert index < len(paddings)
+     return paddings[index]
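
A quick illustrative check (not part of the package) tying this to the paddings built above:

    paddings = get_token_paddings(16, 300, 0)            # [16, 32, 64, 128, 256, 512]
    assert get_padded_token_len(paddings, 100) == 128    # bisect_left -> index 3
    assert get_padded_token_len(paddings, 512) == 512    # an exact match is returned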
+
+
+ class LatencyTracker:
+     """Context manager that logs the wall-clock latency of a block."""
+
+     def __init__(self, name="Operation"):
+         self.name = name
+
+     def __enter__(self):
+         self.start_time = time.perf_counter()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.end_time = time.perf_counter()
+         elapsed_time = self.end_time - self.start_time
+         logger.debug(f"Latency for '{self.name}': {elapsed_time:.3f} seconds")
+
+
+ class ForbidCompile:
+     """
+     A context manager to forbid JAX compilation in a specific block of code.
+
+     It works by temporarily wrapping the internal JAX caching function
+     `_cached_lowering_to_hlo`. If a call within the `with` block results
+     in a cache miss (i.e., triggers a new compilation), it raises a
+     RuntimeError.
+
+     Usage:
+         # This will raise an error because it's the first compilation.
+         with ForbidCompile():
+             jitted_func(x)
+
+         # "Warm up" the cache first.
+         jitted_func(x)
+         # This will now succeed without error.
+         with ForbidCompile():
+             jitted_func(x)
+     """
+
+     def __init__(
+         self,
+         message="JAX compilation occurred but was forbidden in this context."
+     ):
+         self.message = message
+         self._original_func = None
+
+     def __enter__(self):
+         # Store the original function.
+         self._original_func = pxla._cached_lowering_to_hlo
+         original_cached_func = self._original_func
+
+         # Create a wrapper.
+         @functools.wraps(original_cached_func)
+         def wrapper(*args, **kwargs):
+             # Get cache statistics before the call.
+             info_before = original_cached_func.cache_info()
+             misses_before = info_before.misses
+
+             # Execute the original cached function.
+             result = original_cached_func(*args, **kwargs)
+
+             # Get cache statistics after the call.
+             info_after = original_cached_func.cache_info()
+             misses_after = info_after.misses
+
+             # Check if a cache miss occurred.
+             if misses_after > misses_before:
+                 raise RuntimeError(self.message)
+
+             return result
+
+         # Monkey-patch the function with our wrapper.
+         pxla._cached_lowering_to_hlo = wrapper
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         # Restore the original function.
+         if self._original_func:
+             pxla._cached_lowering_to_hlo = self._original_func
+         # Don't suppress any exceptions that occurred inside the 'with' block.
+         return False
+
+
+ def get_batch_composition_stats(
+         input_batch: InputBatch, total_num_scheduled_tokens: int,
+         num_reqs: int, padded_total_num_scheduled_tokens: int,
+         scheduler_output: "VllmSchedulerOutput") -> dict:
+     """
+     Computes the total number of tokens scheduled for the batch, the number
+     of prefill tokens, the number of decode tokens, and the number of padded
+     tokens scheduled for the batch.
+
+     Args:
+         input_batch: The input batch.
+         total_num_scheduled_tokens: The total number of tokens scheduled for
+             the batch.
+         num_reqs: The number of requests in the batch.
+         padded_total_num_scheduled_tokens: The padded total number of tokens
+             scheduled for the batch.
+         scheduler_output: The scheduler output.
+     Returns:
+         A dict containing the total number of tokens scheduled for the batch,
+         the number of prefill tokens, the number of decode tokens, and the
+         number of padded tokens scheduled for the batch.
+     """
+     num_prefill_tokens = 0
+     num_decode_tokens = 0
+
+     # Get the number of scheduled tokens for each request.
+     num_scheduled_tokens_per_req_list = []
+     # Get the number of tokens already processed for each request.
+     num_computed_tokens_per_req = input_batch.num_computed_tokens_cpu[:num_reqs]
+
+     for i, req_id in enumerate(input_batch.req_ids[:num_reqs]):
+         assert req_id is not None
+
+         # The number of tokens to process in the current step for this request.
+         num_scheduled_for_req = scheduler_output.num_scheduled_tokens[req_id]
+         num_scheduled_tokens_per_req_list.append(num_scheduled_for_req)
+
+         # The number of tokens already processed for this request (before
+         # this step).
+         num_already_computed = num_computed_tokens_per_req[i]
+
+         if num_already_computed == 0:
+             # Nothing computed yet, so this is a prefill.
+             num_prefill_tokens += num_scheduled_for_req
+         else:
+             # The request is ongoing.
+             if num_scheduled_for_req > 1:
+                 # Multiple tokens scheduled, so it's chunked prefill.
+                 num_prefill_tokens += num_scheduled_for_req
+             else:
+                 # A single token for an ongoing request, so it's decode.
+                 num_decode_tokens += 1
+     return {
+         "total_num_scheduled_tokens": total_num_scheduled_tokens,
+         "num_prefill_tokens": num_prefill_tokens,
+         "num_decode_tokens": num_decode_tokens,
+         "padded_total_num_scheduled_tokens": padded_total_num_scheduled_tokens,
+         "num_reqs": num_reqs
+     }
+
+
+ def determine_phase_from_batch_composition_stats(
+         batch_composition_stats: dict[str, Any]) -> InferencePhase:
+     """
+     Determines the inference phase based on the batch composition stats.
+
+     Args:
+         batch_composition_stats: The batch composition stats. This is a dict
+             containing:
+             total_num_scheduled_tokens: The total number of tokens scheduled
+                 for the batch.
+             num_prefill_tokens: The number of prefill tokens.
+             num_decode_tokens: The number of decode tokens.
+             padded_total_num_scheduled_tokens: The padded total number of
+                 tokens scheduled for the batch.
+             num_reqs: The number of requests in the batch.
+     Returns:
+         The inference phase enum value.
+     """
+     num_prefill_tokens = batch_composition_stats["num_prefill_tokens"]
+     total_num_scheduled_tokens = batch_composition_stats[
+         "total_num_scheduled_tokens"]
+     prefill_ratio_for_batch = num_prefill_tokens / total_num_scheduled_tokens
+     if prefill_ratio_for_batch >= PREFILL_HEAVY_RATIO_THRESHOLD:
+         return InferencePhase.PREFILL_HEAVY
+     elif prefill_ratio_for_batch <= DECODE_HEAVY_RATIO_THRESHOLD:
+         return InferencePhase.DECODE_HEAVY
+     elif (BALANCED_RATIO_THRESHOLD[0] <= prefill_ratio_for_batch
+           <= BALANCED_RATIO_THRESHOLD[1]):
+         return InferencePhase.BALANCED
+     else:
+         return InferencePhase.AMBIGUOUS
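
To make the thresholds concrete, an illustrative classification sketch (the phase helper is hypothetical, not part of the package):

    from tpu_inference.runner.utils import (
        InferencePhase, determine_phase_from_batch_composition_stats)

    def phase(num_prefill: int) -> InferencePhase:
        stats = {"total_num_scheduled_tokens": 100,
                 "num_prefill_tokens": num_prefill,
                 "num_decode_tokens": 100 - num_prefill,
                 "padded_total_num_scheduled_tokens": 128,
                 "num_reqs": 10}
        return determine_phase_from_batch_composition_stats(stats)

    assert phase(95) == InferencePhase.PREFILL_HEAVY   # 0.95 >= 0.9
    assert phase(10) == InferencePhase.DECODE_HEAVY    # 0.10 <= 0.2
    assert phase(50) == InferencePhase.BALANCED        # 0.4 <= 0.50 <= 0.6
    assert phase(30) == InferencePhase.AMBIGUOUS       # 0.30 matches no band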
+
+
+ class PhasedBasedProfiler:
+     """
+     Implements a phase-based profiler, which will profile three phases:
+     1. Prefill heavy
+     2. Decode heavy
+     3. Balanced
+
+     A phase is determined based on the ratio of prefill tokens to total
+     scheduled tokens for the given batch (see
+     `determine_phase_from_batch_composition_stats`).
+
+     Args:
+         profile_dir: The directory to save the profiles to.
+
+     Attributes:
+         profiling_n_steps_left: The number of steps left to profile for the
+             current phase.
+         profile_dir_with_phase_suffix: The profile directory with the current
+             phase name appended.
+         num_steps_to_profile_for: The number of steps to profile for each
+             phase.
+         profile_dir: The directory to save the profiles to.
+         inference_phase_seen: A dictionary that keeps track of whether a
+             given phase has been seen.
+         default_profiling_options: The default profiling options.
+         current_phase: The current phase.
+     """
+
+     def __init__(self, profile_dir: str):
+         self.profiling_n_steps_left: int = 0
+         self.profile_dir_with_phase_suffix: str | None = None
+         self.num_steps_to_profile_for: int = int(
+             os.getenv("PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR",
+                       PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR))
+         self.profile_dir: str = profile_dir
+         # NOTE: we purposely don't have AMBIGUOUS here.
+         self.inference_phase_seen: dict = {
+             InferencePhase.PREFILL_HEAVY: False,
+             InferencePhase.DECODE_HEAVY: False,
+             InferencePhase.BALANCED: False
+         }
+         self.default_profiling_options = jax.profiler.ProfileOptions()
+         self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
+
+         self.current_phase: str = ""
+
+         logger.info(
+             "Phase-based profiler enabled. Traces will be saved to: %s",
+             self.profile_dir)
+
+     def _write_batch_composition_stats_to_file_helper(
+             self, batch_composition_stats: dict) -> None:
+         """
+         Writes the batch composition stats to a timestamped file, e.g.:
+         prefill_heavy/batch_composition_stats_2025_08_22_15_41_41_505018.json
+         """
+         now = datetime.datetime.now()
+         date_string_in_profiler_format = now.strftime("%Y_%m_%d_%H_%M_%S_%f")
+
+         with open(
+                 os.path.join(
+                     self.profile_dir_with_phase_suffix,
+                     f"batch_composition_stats_{date_string_in_profiler_format}.json"
+                 ), "w") as f:
+             f.write(json.dumps(batch_composition_stats) + "\n")
+
+     def _start_profiling(self, batch_composition_stats: dict) -> None:
+         """
+         Potentially starts profiling for a given unseen phase.
+
+         Args:
+             batch_composition_stats: The batch composition stats, which is a
+                 dict containing:
+                 total_num_scheduled_tokens: The total number of tokens
+                     scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of
+                     tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+         """
+         current_determined_phase = determine_phase_from_batch_composition_stats(
+             batch_composition_stats)
+         for phase, has_been_seen in self.inference_phase_seen.items():
+             if has_been_seen or phase != current_determined_phase:
+                 continue
+
+             self.inference_phase_seen[phase] = True
+             self.profiling_n_steps_left = self.num_steps_to_profile_for
+
+             self.current_phase = phase.name.lower()
+
+             logger.info(f"Starting profiling for {self.current_phase} phase")
+             logger.info(f"Batch composition stats: {batch_composition_stats}")
+             self.profile_dir_with_phase_suffix = os.path.join(
+                 self.profile_dir, self.current_phase)
+
+             # Create the profile subdirectory if it doesn't exist.
+             os.makedirs(self.profile_dir_with_phase_suffix, exist_ok=True)
+
+             # Write the batch composition stats to a file to make it easier
+             # to align with the traces.
+             self._write_batch_composition_stats_to_file_helper(
+                 batch_composition_stats)
+
+             jax.profiler.start_trace(
+                 self.profile_dir_with_phase_suffix,
+                 profiler_options=self.default_profiling_options)
+             break
+
+     def _step_or_stop_profiling(self, batch_composition_stats: dict) -> None:
+         """
+         Steps the profiler, or stops it if we have profiled enough steps for
+         the current phase.
+
+         Args:
+             batch_composition_stats: The batch composition stats, which is a
+                 dict containing:
+                 total_num_scheduled_tokens: The total number of tokens
+                     scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of
+                     tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+         """
+         # We should only decrement profiling_n_steps_left if we are profiling.
+         if self.current_phase != "":
+             self._write_batch_composition_stats_to_file_helper(
+                 batch_composition_stats)
+             self.profiling_n_steps_left -= 1
+             if self.profiling_n_steps_left <= 0:
+                 jax.profiler.stop_trace()
+                 logger.info(
+                     f"Profiling for {self.current_phase} phase finished")
+                 self.current_phase = ""
+
+     def step(self, batch_composition_stats: dict) -> None:
+         """
+         Steps the profiler.
+
+         Args:
+             batch_composition_stats: The batch composition stats, which is a
+                 dict containing:
+                 total_num_scheduled_tokens: The total number of tokens
+                     scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of
+                     tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+         """
+         have_seen_all_phases = all(self.inference_phase_seen.values())
+         # We want to start profiling only after the first trial request.
+         is_past_initial_request = (
+             batch_composition_stats["num_reqs"] > 1
+             and batch_composition_stats["total_num_scheduled_tokens"] > 1)
+         if is_past_initial_request and (not have_seen_all_phases
+                                         or self.current_phase != ""):
+             # We haven't started profiling yet.
+             if self.profiling_n_steps_left <= 0:
+                 self._start_profiling(batch_composition_stats)
+             # We are in the middle of profiling a given phase.
+             else:
+                 self._step_or_stop_profiling(batch_composition_stats)
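
A minimal sketch of how a runner loop might drive this profiler (scheduler, input_batch, padded_total, and run_model are hypothetical stand-ins, not APIs defined in this package):

    profiler = PhasedBasedProfiler(profile_dir="/tmp/tpu_profiles")
    while True:
        scheduler_output = scheduler.schedule()        # hypothetical scheduler
        stats = get_batch_composition_stats(
            input_batch,
            scheduler_output.total_num_scheduled_tokens,
            input_batch.num_reqs,
            padded_total,                              # hypothetical padded size
            scheduler_output)
        profiler.step(stats)   # starts, advances, or stops a trace per phase
        run_model(scheduler_output)                    # hypothetical model step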
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.