tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""
|
|
3
|
+
Implements a few utility functions for the various runners.
|
|
4
|
+
"""
|
|
5
|
+
import bisect
|
|
6
|
+
import datetime
|
|
7
|
+
import functools
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import jax
|
|
15
|
+
from jax._src.interpreters import pxla
|
|
16
|
+
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
|
|
17
|
+
|
|
18
|
+
from tpu_inference.logger import init_logger
|
|
19
|
+
from tpu_inference.runner.input_batch_jax import InputBatch
|
|
20
|
+
|
|
21
|
+
# Floor applied when padding the number of requests in a batch
# (see get_padded_num_reqs_with_upper_limit).
MIN_NUM_SEQS = 8

# Thresholds used by determine_phase_from_batch_composition_stats to classify
# a batch by the fraction of its scheduled tokens that are prefill tokens.

# A batch with at least 90% of its tokens scheduled for prefill is in the
# PREFILL_HEAVY phase.
PREFILL_HEAVY_RATIO_THRESHOLD = 0.9
# A batch with at most 20% of its tokens scheduled for prefill is in the
# DECODE_HEAVY phase.
DECODE_HEAVY_RATIO_THRESHOLD = 0.2
# A batch with between 40% and 60% (inclusive) of its tokens scheduled for
# prefill is in the BALANCED phase. Stored as (lower bound, upper bound).
BALANCED_RATIO_THRESHOLD = (0.4, 0.6)
# Default number of steps PhasedBasedProfiler traces per phase; it can be
# overridden through the environment variable of the same name.
PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR = 15
|
|
35
|
+
|
|
36
|
+
# Module-level logger, created by the project's init_logger helper and keyed
# by this module's name.
logger = init_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class InferencePhase(Enum):
    """Coarse classification of a batch by its share of prefill tokens.

    Values are assigned by determine_phase_from_batch_composition_stats.
    """

    # Prefill ratio is at or above PREFILL_HEAVY_RATIO_THRESHOLD.
    PREFILL_HEAVY = 0
    # Prefill ratio is at or below DECODE_HEAVY_RATIO_THRESHOLD.
    DECODE_HEAVY = 1
    # Prefill ratio lies within BALANCED_RATIO_THRESHOLD (inclusive).
    BALANCED = 2
    # Prefill ratio falls in none of the ranges above.
    AMBIGUOUS = 3
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Pad a request count to a power-of-two bucket, capped at ``upper_limit``.

    Args:
        x: The number of requests to pad.
        upper_limit: The largest value that may be returned.

    Returns:
        MIN_NUM_SEQS when ``x`` is at most MIN_NUM_SEQS, otherwise the
        smallest power of two that is >= ``x``; in both cases the result is
        clamped to ``upper_limit``.
    """
    if x <= MIN_NUM_SEQS:
        padded = MIN_NUM_SEQS
    else:
        # (x - 1).bit_length() is the exponent of the smallest power of two
        # that is greater than or equal to x.
        padded = 1 << (x - 1).bit_length()
    return min(padded, upper_limit)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    """Build the ascending list of request-count padding buckets.

    The first bucket is the larger of MIN_NUM_SEQS and ``min_req_size``; each
    following bucket is produced by get_padded_num_reqs_with_upper_limit, so
    the sequence doubles until it is capped by ``max_req_size``.

    Args:
        min_req_size: Smallest bucket requested; must be a positive power of
            two (checked with an assert).
        max_req_size: Largest bucket allowed.

    Returns:
        The list of bucket sizes, which is empty when the first bucket already
        exceeds ``max_req_size``.
    """
    # min_req_size has to be a positive power of two.
    is_power_of_two = min_req_size & (min_req_size - 1) == 0
    assert is_power_of_two and min_req_size > 0

    paddings: list = []
    candidate = max(MIN_NUM_SEQS, min_req_size)
    while candidate <= max_req_size:
        # Once the cap is reached the helper keeps returning the same value;
        # stop instead of appending a duplicate.
        if paddings and paddings[-1] == candidate:
            break
        paddings.append(candidate)
        candidate = get_padded_num_reqs_with_upper_limit(
            candidate + 1, max_req_size)

    logger.info(f"Prepared request paddings: {paddings}")
    return paddings
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_token_paddings(min_token_size: int, max_token_size: int,
                       padding_gap: int) -> list[int]:
    """Build the list of token-count padding sizes.

    The list starts from ``min_token_size`` and ends with a value that covers
    ``max_token_size``.

    * ``padding_gap == 0``: sizes double each time (exponential growth).
    * otherwise: sizes double while they are <= ``padding_gap``, then grow
      linearly in steps of ``padding_gap``.

    Args:
        min_token_size: First padding size; must be a positive power of two
            (checked with an assert).
        max_token_size: The token count the final padding size must cover.
        padding_gap: Linear step size, or 0 for purely exponential growth.

    Returns:
        The list of padding sizes in ascending order.
    """
    # min_token_size has to be a positive power of two.
    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0

    paddings = []
    size = min_token_size

    if padding_gap == 0:
        # Always emit the first size, then keep doubling until the latest
        # size covers max_token_size.
        paddings.append(size)
        while size < max_token_size:
            size *= 2
            paddings.append(size)
    else:
        # Exponential part: double while the size does not exceed the gap.
        while size <= padding_gap:
            paddings.append(size)
            size *= 2
        # Step back one doubling before switching to linear growth. Note that
        # this halving also happens when the loop above never ran.
        size //= 2
        # Linear part: advance by padding_gap until max_token_size is covered.
        while size < max_token_size:
            size += padding_gap
            paddings.append(size)

    logger.info(f"Prepared token paddings: {paddings}")
    return paddings
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_padded_token_len(paddings: list[int], x: int) -> int:
    """Look up the padding size to use for ``x`` tokens.

    Args:
        paddings: Candidate padding sizes; the binary search requires them to
            be sorted in ascending order.
        x: The token count to pad.

    Returns:
        The first element of ``paddings`` that is greater than or equal to
        ``x``. An assert fails when no element is large enough.
    """
    # bisect_left gives the position of the leftmost element that is >= x.
    position = bisect.bisect_left(paddings, x)
    assert position < len(paddings)
    return paddings[position]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class LatencyTracker:
    """Context manager that logs how long its ``with`` block took.

    The elapsed time is measured with ``time.perf_counter`` and logged at
    debug level when the block exits.

    Attributes:
        name: Label used in the log message.
        start_time: perf_counter reading taken on entry.
        end_time: perf_counter reading taken on exit.
    """

    def __init__(self, name="Operation"):
        self.name = name

    def __enter__(self):
        """Record the start time and return the tracker itself."""
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the end time and log the elapsed duration.

        Returns None, so an exception raised inside the block propagates.
        """
        self.end_time = time.perf_counter()
        elapsed_time = self.end_time - self.start_time
        logger.debug(f"Latency for '{self.name}': {elapsed_time:.3f} seconds")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class ForbidCompile:
    """
    A context manager to forbid JAX compilation in a specific block of code.

    On entry it replaces the JAX-internal caching function
    `pxla._cached_lowering_to_hlo` with a wrapper. The wrapper compares the
    cache's miss counter before and after each call; if the counter grew
    (i.e., the call triggered a new compilation), it raises a RuntimeError.
    On exit the original function is put back.

    NOTE: this relies on a private JAX attribute and on that attribute
    exposing `cache_info()`.

    Usage:
        # This will raise an error because it's the first compilation.
        with ForbidCompile():
            jitted_func(x)

        # "Warm up" the cache first.
        jitted_func(x)
        # This will now succeed without error.
        with ForbidCompile():
            jitted_func(x)
    """

    def __init__(
        self,
        message="JAX compilation occurred but was forbidden in this context."
    ):
        # Text of the RuntimeError raised when a cache miss is detected.
        self.message = message
        # Holds the un-patched function while the context is active.
        self._original_func = None

    def __enter__(self):
        """Install the miss-detecting wrapper (returns None)."""
        cached_lowering = pxla._cached_lowering_to_hlo
        self._original_func = cached_lowering

        @functools.wraps(cached_lowering)
        def guarded(*args, **kwargs):
            # Snapshot the miss counter, run the real cached function, then
            # see whether the counter moved.
            misses_before = cached_lowering.cache_info().misses
            outcome = cached_lowering(*args, **kwargs)
            if cached_lowering.cache_info().misses > misses_before:
                raise RuntimeError(self.message)
            return outcome

        # Monkey-patch the module attribute with the guard.
        pxla._cached_lowering_to_hlo = guarded

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original function; never suppresses exceptions."""
        if self._original_func:
            pxla._cached_lowering_to_hlo = self._original_func
        # Returning False lets any exception from the 'with' block propagate.
        return False
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def get_batch_composition_stats(
        input_batch: "InputBatch", total_num_scheduled_tokens: int,
        num_reqs: int, padded_total_num_scheduled_tokens: int,
        scheduler_output: "VllmSchedulerOutput") -> dict:
    """
    Computes how the tokens scheduled for a batch split between prefill and
    decode.

    A request's tokens count as prefill when nothing has been computed for it
    yet, or when more than one token is scheduled for it in this step (chunked
    prefill). Otherwise the request contributes exactly one decode token.

    Args:
        input_batch: The input batch; its `req_ids` and
            `num_computed_tokens_cpu` are read for the first `num_reqs`
            requests.
        total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
        num_reqs: The number of requests in the batch.
        padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
        scheduler_output: The scheduler output; its `num_scheduled_tokens`
            mapping is indexed by request id.
    Returns:
        A dict with the keys `total_num_scheduled_tokens`,
        `num_prefill_tokens`, `num_decode_tokens`,
        `padded_total_num_scheduled_tokens` and `num_reqs`. The totals and
        `num_reqs` are passed through unchanged from the arguments.
    """
    num_prefill_tokens = 0
    num_decode_tokens = 0

    # Number of tokens already processed for each request (before this step).
    num_computed_tokens_per_req = input_batch.num_computed_tokens_cpu[:
                                                                      num_reqs]

    for i, req_id in enumerate(input_batch.req_ids[:num_reqs]):
        assert req_id is not None

        # Number of tokens to process in the current step for this request.
        num_scheduled_for_req = scheduler_output.num_scheduled_tokens[req_id]
        num_already_computed = num_computed_tokens_per_req[i]

        if num_already_computed == 0 or num_scheduled_for_req > 1:
            # Either a fresh request or a multi-token chunk of an ongoing
            # one: both are counted as prefill.
            num_prefill_tokens += num_scheduled_for_req
        else:
            # An ongoing request that is not a multi-token chunk counts as a
            # single decode token.
            num_decode_tokens += 1

    return {
        "total_num_scheduled_tokens": total_num_scheduled_tokens,
        "num_prefill_tokens": num_prefill_tokens,
        "num_decode_tokens": num_decode_tokens,
        "padded_total_num_scheduled_tokens": padded_total_num_scheduled_tokens,
        "num_reqs": num_reqs
    }
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def determine_phase_from_batch_composition_stats(
        batch_composition_stats: dict[str, Any]) -> InferencePhase:
    """
    Determines the inference phase based on the batch composition stats.

    The phase is chosen from the ratio of prefill tokens to total scheduled
    tokens, compared against PREFILL_HEAVY_RATIO_THRESHOLD,
    DECODE_HEAVY_RATIO_THRESHOLD and BALANCED_RATIO_THRESHOLD (in that order).

    Args:
        batch_composition_stats: The batch composition stats.
            This is a dict containing:
                total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
                num_prefill_tokens: The number of prefill tokens.
                num_decode_tokens: The number of decode tokens.
                padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
                num_reqs: The number of requests in the batch.
            Only `num_prefill_tokens` and `total_num_scheduled_tokens` are
            read here.
    Returns:
        The inference phase enum value. AMBIGUOUS is returned when the ratio
        falls in none of the configured ranges, or when no tokens are
        scheduled at all.
    """
    num_prefill_tokens = batch_composition_stats["num_prefill_tokens"]
    total_num_scheduled_tokens = batch_composition_stats[
        "total_num_scheduled_tokens"]

    # With nothing scheduled the ratio is undefined; report AMBIGUOUS rather
    # than dividing by zero.
    if total_num_scheduled_tokens == 0:
        return InferencePhase.AMBIGUOUS

    prefill_ratio_for_batch = num_prefill_tokens / total_num_scheduled_tokens
    balanced_low, balanced_high = BALANCED_RATIO_THRESHOLD

    if prefill_ratio_for_batch >= PREFILL_HEAVY_RATIO_THRESHOLD:
        return InferencePhase.PREFILL_HEAVY
    if prefill_ratio_for_batch <= DECODE_HEAVY_RATIO_THRESHOLD:
        return InferencePhase.DECODE_HEAVY
    if balanced_low <= prefill_ratio_for_batch <= balanced_high:
        return InferencePhase.BALANCED
    return InferencePhase.AMBIGUOUS
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class PhasedBasedProfiler:
    """
    Implements a phased-based profiler, which will profile three phases:
        1. Prefill heavy
        2. Decode heavy
        3. Balanced

    A phase is determined based on the ratio of prefill tokens to total scheduled
    tokens for the given batch (see `determine_phase_from_batch_composition_stats`).
    Each phase is traced at most once, the first time a batch in that phase is
    seen while no other trace is running.

    Args:
        profile_dir: The directory to save the profiles to.

    Attributes:
        profiling_n_steps_left: The number of steps left to profile for the current phase.
        profile_dir_with_phase_suffix: The per-phase subdirectory of
            `profile_dir` that the current trace is written to (None until the
            first trace starts).
        num_steps_to_profile_for: The number of steps to profile for each phase.
            Read from the PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR environment
            variable, falling back to the module constant of the same name.
        profile_dir: The directory to save the profiles to.
        inference_phase_seen: A dictionary that keeps track of whether a given phase has been seen.
        default_profiling_options: The default profiling options.
        current_phase: The lower-cased name of the phase being profiled, or ""
            when no trace is running.
    """

    def __init__(self, profile_dir: str):
        self.profiling_n_steps_left: int = 0
        self.profile_dir_with_phase_suffix: str | None = None
        self.num_steps_to_profile_for: int = int(
            os.getenv("PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR",
                      PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR))
        self.profile_dir: str = profile_dir
        # NOTE: we purposely don't have AMBIGUOUS here
        self.inference_phase_seen: dict = {
            InferencePhase.PREFILL_HEAVY: False,
            InferencePhase.DECODE_HEAVY: False,
            InferencePhase.BALANCED: False
        }
        self.default_profiling_options = jax.profiler.ProfileOptions()
        # os.getenv returns a string when the variable is set, so coerce to
        # int before handing the value to the profiler options.
        self.default_profiling_options.python_tracer_level = int(
            os.getenv("PYTHON_TRACER_LEVEL", 0))

        self.current_phase: str = ""

        logger.info(
            "Phased-based profiler enabled. Traces will be saved to: %s",
            self.profile_dir)

    def _write_batch_composition_stats_to_file_helper(
            self, batch_composition_stats: dict) -> None:
        """
        Writes the batch composition stats to a file at the given time,
        e.g.: prefill_heavy/batch_composition_stats_2025_08_22_15_41_41_505018.json

        The file is created inside `profile_dir_with_phase_suffix` and holds
        the stats as a single JSON line.
        """
        now = datetime.datetime.now()
        date_string_in_profiler_format = now.strftime("%Y_%m_%d_%H_%M_%S_%f")
        stats_path = os.path.join(
            self.profile_dir_with_phase_suffix,
            f"batch_composition_stats_{date_string_in_profiler_format}.json")

        with open(stats_path, "w") as f:
            f.write(json.dumps(batch_composition_stats) + "\n")

    def _start_profiling(self, batch_composition_stats: dict) -> None:
        """
        Potentially starts profiling for a given unseen phase.

        Nothing happens when the batch's phase has already been profiled or
        is not one of the tracked phases (e.g. AMBIGUOUS).

        Args:
            batch_composition_stats: The batch composition stats, which is a dict
            containing:
                total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
                num_prefill_tokens: The number of prefill tokens.
                num_decode_tokens: The number of decode tokens.
                padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
                num_reqs: The number of requests in the batch.
        """
        phase = determine_phase_from_batch_composition_stats(
            batch_composition_stats)
        # Phases missing from the dict (AMBIGUOUS) are treated as already
        # seen, so they never start a trace.
        if self.inference_phase_seen.get(phase, True):
            return

        self.inference_phase_seen[phase] = True
        self.profiling_n_steps_left = self.num_steps_to_profile_for

        self.current_phase = phase.name.lower()

        logger.info(f"Starting profiling for {self.current_phase} phase")
        logger.info(f"Batch composition stats: {batch_composition_stats}")
        self.profile_dir_with_phase_suffix = os.path.join(
            self.profile_dir, self.current_phase)

        # Create the profile subdirectory if it doesn't exist
        os.makedirs(self.profile_dir_with_phase_suffix, exist_ok=True)

        # Write the batch composition stats to a file to make it easier to
        # align with the traces
        self._write_batch_composition_stats_to_file_helper(
            batch_composition_stats)

        jax.profiler.start_trace(
            self.profile_dir_with_phase_suffix,
            profiler_options=self.default_profiling_options)

    def _step_or_stop_profiling(self, batch_composition_stats: dict) -> None:
        """
        Steps the profiler or stops it if we have profiled enough steps for the
        current phase.

        Args:
            batch_composition_stats: The batch composition stats, which is a dict
            containing:
                total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
                num_prefill_tokens: The number of prefill tokens.
                num_decode_tokens: The number of decode tokens.
                padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
                num_reqs: The number of requests in the batch.
        """
        # We only should decrement the profiling_n_steps_left if we are profiling
        if self.current_phase == "":
            return

        self._write_batch_composition_stats_to_file_helper(
            batch_composition_stats)
        self.profiling_n_steps_left -= 1
        if self.profiling_n_steps_left <= 0:
            jax.profiler.stop_trace()
            logger.info(f"Profiling for {self.current_phase} phase finished")
            self.current_phase = ""

    def step(self, batch_composition_stats: dict) -> None:
        """
        Steps the profiler.

        Batches with no requests, or with at most one scheduled token, are
        ignored, as is every batch once all phases have been profiled and no
        trace is running.

        Args:
            batch_composition_stats: The batch composition stats, which is a dict
            containing:
                total_num_scheduled_tokens: The total number of tokens scheduled for the batch.
                num_prefill_tokens: The number of prefill tokens.
                num_decode_tokens: The number of decode tokens.
                padded_total_num_scheduled_tokens: The padded total number of tokens scheduled for the batch.
                num_reqs: The number of requests in the batch.
        """
        have_seen_all_phases = all(self.inference_phase_seen.values())
        is_profiling = self.current_phase != ""
        # We want to start profiling only after the first trial request
        is_past_initial_request = (
            batch_composition_stats["num_reqs"] >= 1
            and batch_composition_stats["total_num_scheduled_tokens"] > 1)

        if not is_past_initial_request:
            return
        if have_seen_all_phases and not is_profiling:
            return

        if is_profiling:
            # We are in the middle of profiling a given phase. Dispatching on
            # the active phase (rather than on the remaining step count)
            # guarantees a running trace is always stepped/stopped, even when
            # the configured step count is not positive.
            self._step_or_stop_profiling(batch_composition_stats)
        else:
            # We haven't started profiling yet
            self._start_profiling(batch_composition_stats)
|
|
File without changes
|
|
File without changes
|