tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/runner/utils.py
@@ -0,0 +1,426 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ Implements a few utility functions for the various runners.
+ """
+ import bisect
+ import datetime
+ import functools
+ import json
+ import os
+ import time
+ from enum import Enum
+ from typing import Any
+
+ import jax
+ from jax._src.interpreters import pxla
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.input_batch_jax import InputBatch
+
+ MIN_NUM_SEQS = 8
+
+ # These are used for determining the inference phase for a given batch in
+ # determine_phase_from_batch_composition_stats.
+ # We will say that any batch that has at least 90% of its tokens scheduled for
+ # prefilling is in the PREFILL_HEAVY phase.
+ PREFILL_HEAVY_RATIO_THRESHOLD = 0.9
+ # We will say that any batch that has at most 20% of its tokens scheduled for
+ # prefilling is in the DECODE_HEAVY phase.
+ DECODE_HEAVY_RATIO_THRESHOLD = 0.2
+ # We will say that any batch that has between 40% and 60% of its tokens scheduled
+ # for prefilling is in the BALANCED phase.
+ BALANCED_RATIO_THRESHOLD = (0.4, 0.6)
+ PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR = 15
+
+ logger = init_logger(__name__)
+
+
+ class InferencePhase(Enum):
+     PREFILL_HEAVY = 0
+     DECODE_HEAVY = 1
+     BALANCED = 2
+     AMBIGUOUS = 3
+
+
+ def get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
+     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
+     return min(res, upper_limit)
+
+
+ def get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
+     # assert that min_req_size is a power of 2
+     assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
+     paddings: list = []
+     num = max(MIN_NUM_SEQS, min_req_size)
+     while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
+         paddings.append(num)
+         num = get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
+     logger.info(f"Prepared request paddings: {paddings}")
+     return paddings
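# A minimal sketch (editor's illustration, assuming this module is importable as
# tpu_inference.runner.utils): request counts are padded up to MIN_NUM_SEQS or the
# next power of two, capped at the upper limit, so the padding schedule for a
# hypothetical maximum of 100 requests is [8, 16, 32, 64, 100].
from tpu_inference.runner.utils import (get_padded_num_reqs_with_upper_limit,
                                        get_req_paddings)

assert get_padded_num_reqs_with_upper_limit(3, 100) == 8     # below MIN_NUM_SEQS
assert get_padded_num_reqs_with_upper_limit(9, 100) == 16    # next power of two
assert get_padded_num_reqs_with_upper_limit(90, 100) == 100  # capped at the limit
assert get_req_paddings(8, 100) == [8, 16, 32, 64, 100]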
+
+
+ def get_token_paddings(min_token_size: int, max_token_size: int,
+                        padding_gap: int) -> list[int]:
+     """Generate a list of padding sizes, starting from min_token_size and
+     ending with a number that can cover max_token_size.
+
+     If padding_gap == 0:
+         double the size each time (exponential growth)
+     else:
+         double the size until it exceeds padding_gap,
+         then increase the size by padding_gap each step.
+     """
+     # assert that min_token_size is a power of 2
+     assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
+     paddings = []
+     num = min_token_size
+
+     if padding_gap == 0:
+         while True:
+             paddings.append(num)
+             if num >= max_token_size:
+                 break
+             num *= 2
+     else:
+         while num <= padding_gap:
+             paddings.append(num)
+             num *= 2
+         num //= 2
+         while num < max_token_size:
+             num += padding_gap
+             paddings.append(num)
+     logger.info(f"Prepared token paddings: {paddings}")
+     return paddings
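# A minimal sketch of the two padding schedules described in the docstring above
# (the numeric arguments are hypothetical): with padding_gap == 0 the sizes double
# until max_token_size is covered; otherwise they double up to padding_gap and then
# grow linearly by padding_gap.
from tpu_inference.runner.utils import get_token_paddings

assert get_token_paddings(16, 300, 0) == [16, 32, 64, 128, 256, 512]
assert get_token_paddings(16, 300, 64) == [16, 32, 64, 128, 192, 256, 320]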
+
+
+ def get_padded_token_len(paddings: list[int], x: int) -> int:
+     """Return the first element in paddings greater than or equal to x.
+     """
+     index = bisect.bisect_left(paddings, x)
+     assert index < len(paddings)
+     return paddings[index]
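# A small follow-on sketch (same hypothetical values as above): a scheduled token
# count is snapped up to the next padding bucket, and the assert inside
# get_padded_token_len fires if x exceeds the largest padding.
from tpu_inference.runner.utils import get_padded_token_len, get_token_paddings

token_paddings = get_token_paddings(16, 300, 0)  # [16, 32, 64, 128, 256, 512]
assert get_padded_token_len(token_paddings, 100) == 128
assert get_padded_token_len(token_paddings, 256) == 256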
+
+
+ class LatencyTracker:
+
+     def __init__(self, name="Operation"):
+         self.name = name
+
+     def __enter__(self):
+         self.start_time = time.perf_counter()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.end_time = time.perf_counter()
+         elapsed_time = self.end_time - self.start_time
+         logger.debug(f"Latency for '{self.name}': {elapsed_time:.3f} seconds")
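# A minimal usage sketch: LatencyTracker only measures wall-clock time with
# time.perf_counter() and logs it at DEBUG level, so nothing is printed unless
# debug logging is enabled for this logger. The name below is an arbitrary example.
import time

from tpu_inference.runner.utils import LatencyTracker

with LatencyTracker("kv_cache_swap"):
    time.sleep(0.01)  # stand-in for the operation being timed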
+
+
+ class ForbidCompile:
+     """
+     A context manager to forbid JAX compilation in a specific block of code.
+
+     It works by temporarily wrapping the internal JAX caching function
+     `_cached_lowering_to_hlo`. If a call within the `with` block results
+     in a cache miss (i.e., triggers a new compilation), it raises a
+     RuntimeError.
+
+     Usage:
+         # This will raise an error because it's the first compilation.
+         with ForbidCompile():
+             jitted_func(x)
+
+         # "Warm up" the cache first.
+         jitted_func(x)
+         # This will now succeed without error.
+         with ForbidCompile():
+             jitted_func(x)
+     """
+
+     def __init__(
+         self,
+         message="JAX compilation occurred but was forbidden in this context."
+     ):
+         self.message = message
+         self._original_func = None
+
+     def __enter__(self):
+         # Store the original function
+         self._original_func = pxla._cached_lowering_to_hlo
+         original_cached_func = self._original_func
+
+         # Create a wrapper
+         @functools.wraps(original_cached_func)
+         def wrapper(*args, **kwargs):
+             # Get cache statistics before the call
+             info_before = original_cached_func.cache_info()
+             misses_before = info_before.misses
+
+             # Execute the original cached function
+             result = original_cached_func(*args, **kwargs)
+
+             # Get cache statistics after the call
+             info_after = original_cached_func.cache_info()
+             misses_after = info_after.misses
+
+             # Check if a cache miss occurred
+             if misses_after > misses_before:
+                 raise RuntimeError(self.message)
+
+             return result
+
+         # Monkey-patch the function with our wrapper
+         pxla._cached_lowering_to_hlo = wrapper
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         # Restore the original function
+         if self._original_func:
+             pxla._cached_lowering_to_hlo = self._original_func
+         # Don't suppress any exceptions that occurred inside the 'with' block
+         return False
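# A concrete sketch of the usage pattern from the docstring above, assuming only
# that jax is installed and that the pxla wrapping used by ForbidCompile still
# applies to the installed jax version: warm the jit cache first, then enter
# ForbidCompile to check that no new compilation is triggered.
import jax
import jax.numpy as jnp

from tpu_inference.runner.utils import ForbidCompile


@jax.jit
def double(x):
    return x * 2


double(jnp.ones((4,)))       # warm-up: the only compilation happens here
with ForbidCompile():
    double(jnp.ones((4,)))   # cache hit, so no RuntimeError is raised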
+
+
+ def get_batch_composition_stats(
+         input_batch: InputBatch, total_num_scheduled_tokens: int,
+         num_reqs: int, padded_total_num_scheduled_tokens: int,
+         scheduler_output: "VllmSchedulerOutput") -> dict:
+     """
+     Computes batch composition statistics: the total number of tokens scheduled
+     for the batch, the number of prefill tokens, the number of decode tokens,
+     and the padded number of tokens scheduled for the batch.
+     Args:
+         input_batch: The input batch.
+         total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
+         num_reqs: The number of requests in the batch.
+         padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
+         scheduler_output: The scheduler output.
+     Returns:
+         A dict containing the total number of scheduled tokens, the number of prefill
+         and decode tokens, the padded number of scheduled tokens, and the number of requests.
+     """
+     num_prefill_tokens = 0
+     num_decode_tokens = 0
+
+     # Get the number of scheduled tokens for each request.
+     num_scheduled_tokens_per_req_list = []
+     # Get the number of tokens already processed for each request.
+     num_computed_tokens_per_req = input_batch.num_computed_tokens_cpu[:
+                                                                       num_reqs]
+
+     for i, req_id in enumerate(input_batch.req_ids[:num_reqs]):
+         assert req_id is not None
+
+         # This is the number of tokens to process in the current step for this request
+         num_scheduled_for_req = scheduler_output.num_scheduled_tokens[req_id]
+         num_scheduled_tokens_per_req_list.append(num_scheduled_for_req)
+
+         # This is the number of tokens already processed for this request (before this step)
+         num_already_computed = num_computed_tokens_per_req[i]
+
+         if num_already_computed == 0:
+             # Prefill
+             num_prefill_tokens += num_scheduled_for_req
+         # Otherwise the request is ongoing (some tokens are already computed)
+         else:
+             if num_scheduled_for_req > 1:
+                 # It's a multi-token request, so it's chunked prefill
+                 num_prefill_tokens += num_scheduled_for_req
+             else:
+                 # It's a single token for an ongoing request, so it's decode
+                 num_decode_tokens += 1
+     return {
+         "total_num_scheduled_tokens": total_num_scheduled_tokens,
+         "num_prefill_tokens": num_prefill_tokens,
+         "num_decode_tokens": num_decode_tokens,
+         "padded_total_num_scheduled_tokens": padded_total_num_scheduled_tokens,
+         "num_reqs": num_reqs
+     }
+
+
+ def determine_phase_from_batch_composition_stats(
+         batch_composition_stats: dict[str, Any]) -> InferencePhase:
+     """
+     Determines the inference phase based on the batch composition stats.
+
+     Args:
+         batch_composition_stats: The batch composition stats.
+             This is a dict containing:
+                 total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+     Returns:
+         The inference phase enum value.
+     """
+     num_prefill_tokens = batch_composition_stats["num_prefill_tokens"]
+     total_num_scheduled_tokens = batch_composition_stats[
+         "total_num_scheduled_tokens"]
+     prefill_ratio_for_batch = num_prefill_tokens / total_num_scheduled_tokens
+     if prefill_ratio_for_batch >= PREFILL_HEAVY_RATIO_THRESHOLD:
+         return InferencePhase.PREFILL_HEAVY
+     elif prefill_ratio_for_batch <= DECODE_HEAVY_RATIO_THRESHOLD:
+         return InferencePhase.DECODE_HEAVY
+     elif prefill_ratio_for_batch >= BALANCED_RATIO_THRESHOLD[
+             0] and prefill_ratio_for_batch <= BALANCED_RATIO_THRESHOLD[1]:
+         return InferencePhase.BALANCED
+     else:
+         return InferencePhase.AMBIGUOUS
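# A worked example with hypothetical numbers: 950 of the 1000 scheduled tokens are
# prefill, so the prefill ratio is 0.95 >= PREFILL_HEAVY_RATIO_THRESHOLD (0.9) and
# the batch is classified as PREFILL_HEAVY.
from tpu_inference.runner.utils import (
    InferencePhase, determine_phase_from_batch_composition_stats)

stats = {
    "total_num_scheduled_tokens": 1000,
    "num_prefill_tokens": 950,
    "num_decode_tokens": 50,
    "padded_total_num_scheduled_tokens": 1024,
    "num_reqs": 64,
}
assert determine_phase_from_batch_composition_stats(
    stats) == InferencePhase.PREFILL_HEAVY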
+
+
+ class PhasedBasedProfiler:
+     """
+     Implements a phase-based profiler, which profiles three phases:
+     1. Prefill heavy
+     2. Decode heavy
+     3. Balanced
+
+     A phase is determined based on the ratio of prefill tokens to total scheduled
+     tokens for the given batch (see `determine_phase_from_batch_composition_stats`).
+
+     Args:
+         profile_dir: The directory to save the profiles to.
+
+     Attributes:
+         profiling_n_steps_left: The number of steps left to profile for the current phase.
+         profile_dir_with_phase_suffix: The per-phase subdirectory the current traces are saved to.
+         num_steps_to_profile_for: The number of steps to profile for each phase.
+         profile_dir: The directory to save the profiles to.
+         inference_phase_seen: A dictionary that keeps track of whether a given phase has been seen.
+         default_profiling_options: The default profiling options.
+         current_phase: The current phase.
+     """
+
+     def __init__(self, profile_dir: str):
+         self.profiling_n_steps_left: int = 0
+         self.profile_dir_with_phase_suffix: str = None
+         self.num_steps_to_profile_for: int = int(
+             os.getenv("PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR",
+                       PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR))
+         self.profile_dir: str = profile_dir
+         # NOTE: we purposely don't have AMBIGUOUS here
+         self.inference_phase_seen: dict = {
+             InferencePhase.PREFILL_HEAVY: False,
+             InferencePhase.DECODE_HEAVY: False,
+             InferencePhase.BALANCED: False
+         }
+         self.default_profiling_options = jax.profiler.ProfileOptions()
+         self.default_profiling_options.python_tracer_level = os.getenv(
+             "PYTHON_TRACER_LEVEL", 0)
+
+         self.current_phase: str = ""
+
+         logger.info(
+             "Phased-based profiler enabled. Traces will be saved to: %s",
+             self.profile_dir)
+
+     def _write_batch_composition_stats_to_file_helper(
+             self, batch_composition_stats: dict) -> None:
+         """
+         Writes the batch composition stats to a file at the given time,
+         e.g.: prefill_heavy/batch_composition_stats_2025_08_22_15_41_41_505018.json
+         """
+         now = datetime.datetime.now()
+         date_string_in_profiler_format = now.strftime("%Y_%m_%d_%H_%M_%S_%f")
+
+         with open(
+                 os.path.join(
+                     self.profile_dir_with_phase_suffix,
+                     f"batch_composition_stats_{date_string_in_profiler_format}.json"
+                 ), "w") as f:
+             f.write(json.dumps(batch_composition_stats) + "\n")
+
+     def _start_profiling(self, batch_composition_stats: dict) -> None:
+         """
+         Potentially starts profiling for a given unseen phase.
+
+         Args:
+             batch_composition_stats: The batch composition stats, which is a dict
+             containing:
+                 total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+         """
+         current_determined_phase = determine_phase_from_batch_composition_stats(
+             batch_composition_stats)
+         for phase, has_been_seen in self.inference_phase_seen.items():
+             if has_been_seen or phase != current_determined_phase:
+                 continue
+
+             self.inference_phase_seen[phase] = True
+             self.profiling_n_steps_left = self.num_steps_to_profile_for
+
+             self.current_phase = phase.name.lower()
+
+             logger.info(f"Starting profiling for {self.current_phase} phase")
+             logger.info(f"Batch composition stats: {batch_composition_stats}")
+             self.profile_dir_with_phase_suffix = os.path.join(
+                 self.profile_dir, self.current_phase)
+
+             # Create the profile subdirectory if it doesn't exist
+             os.makedirs(self.profile_dir_with_phase_suffix, exist_ok=True)
+
+             # Write the batch composition stats to a file to make it easier to
+             # align with the traces
+             self._write_batch_composition_stats_to_file_helper(
+                 batch_composition_stats)
+
+             jax.profiler.start_trace(
+                 self.profile_dir_with_phase_suffix,
+                 profiler_options=self.default_profiling_options)
+             break
+
+     def _step_or_stop_profiling(self, batch_composition_stats: dict) -> None:
+         """
+         Steps the profiler or stops it if we have profiled enough steps for the
+         current phase.
+
+         Args:
+             batch_composition_stats: The batch composition stats, which is a dict
+             containing:
+                 total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+         """
+         # We should only decrement profiling_n_steps_left if we are profiling
+         if self.current_phase != "":
+             self._write_batch_composition_stats_to_file_helper(
+                 batch_composition_stats)
+             self.profiling_n_steps_left -= 1
+             if self.profiling_n_steps_left <= 0:
+                 jax.profiler.stop_trace()
+                 logger.info(
+                     f"Profiling for {self.current_phase} phase finished")
+                 self.current_phase = ""
+
+     def step(self, batch_composition_stats: dict) -> None:
+         """
+         Steps the profiler.
+
+         Args:
+             batch_composition_stats: The batch composition stats, which is a dict
+             containing:
+                 total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
+                 num_prefill_tokens: The number of prefill tokens.
+                 num_decode_tokens: The number of decode tokens.
+                 padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
+                 num_reqs: The number of requests in the batch.
+         """
+         have_seen_all_phases = all(self.inference_phase_seen.values())
+         # We want to start profiling only after the first trial request
+         is_past_initial_request = batch_composition_stats[
+             "num_reqs"] >= 1 and batch_composition_stats[
+                 "total_num_scheduled_tokens"] > 1
+         if is_past_initial_request and (not have_seen_all_phases
+                                         or self.current_phase != ""):
+             # We haven't started profiling yet
+             if self.profiling_n_steps_left <= 0:
+                 self._start_profiling(batch_composition_stats)
+             # We are in the middle of profiling a given phase
+             else:
+                 self._step_or_stop_profiling(batch_composition_stats)
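# A minimal wiring sketch (editor's illustration, assuming tpu_inference and its jax
# and vllm dependencies are installed): in the runner loop the per-step batch stats
# would normally come from get_batch_composition_stats; here a hand-built,
# hypothetical dict stands in for them. The profiler starts a JAX trace the first
# time each phase (prefill-heavy, decode-heavy, balanced) is observed and stops it
# after num_steps_to_profile_for further steps. The output directory is arbitrary.
from tpu_inference.runner.utils import PhasedBasedProfiler

profiler = PhasedBasedProfiler(profile_dir="/tmp/phase_profiles")
prefill_heavy_stats = {
    "total_num_scheduled_tokens": 2048,
    "num_prefill_tokens": 2000,
    "num_decode_tokens": 48,
    "padded_total_num_scheduled_tokens": 2048,
    "num_reqs": 32,
}
for _ in range(20):                     # a few runner steps with the same mix
    profiler.step(prefill_heavy_stats)  # first call starts the prefill-heavy trace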