tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,672 @@
1
+ import os
2
+ import time
3
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ import vllm.envs as envs
9
+ from jax.sharding import NamedSharding, PartitionSpec
10
+ from vllm.utils import cdiv
11
+
12
+ from tpu_inference.core.disagg_utils import is_disagg_enabled
13
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
14
+ from tpu_inference.layers.jax.sample.sampling import sample
15
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
16
+ TPUSupportedSamplingMetadata
17
+ from tpu_inference.logger import init_logger
18
+ from tpu_inference.utils import device_array
19
+
20
+ if TYPE_CHECKING:
21
+ from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
22
+
23
+ logger = init_logger(__name__)
24
+
25
+ # Constants for block bucketing in disaggregated utilities
26
+ BLOCK_BUCKETS = [1, 2, 4, 8, 16, 32, 64]
27
+
28
+
29
+ class CompilationManager:
30
+
31
+ def __init__(self, runner: "TPUModelRunner"):
32
+ self.runner = runner
33
+ if not envs.VLLM_DISABLE_COMPILE_CACHE:
34
+ logger.info("Enabling JAX compile cache.")
35
+ jax.config.update("jax_compilation_cache_dir",
36
+ envs.VLLM_XLA_CACHE_PATH)
37
+
38
+ def _create_dummy_tensor(self,
39
+ shape: Tuple[int, ...],
40
+ dtype: Any,
41
+ sharding: Optional[NamedSharding] = None) -> Any:
42
+ """Helper to create dummy tensors for precompilation."""
43
+ tensor = jnp.ones(shape, dtype=dtype)
44
+ if sharding:
45
+ return device_array(self.runner.mesh, tensor, sharding=sharding)
46
+ return device_array(self.runner.mesh, tensor)
47
+
48
+ def _should_skip_padding_combination(self, outer_val: int, inner_val: int,
49
+ only_equal: bool) -> bool:
50
+ """Helper to determine if we should skip this padding combination."""
51
+ if only_equal:
52
+ return inner_val != outer_val
53
+ return inner_val > outer_val
54
+
55
+ def _run_compilation(self, name: str, fn: Callable, *args,
56
+ **kwargs) -> None:
57
+ logger.info(f"Precompile {name} --> {kwargs}")
58
+ start = time.perf_counter()
59
+ result = fn(*args)
60
+ if result is not None:
61
+ if isinstance(result, tuple):
62
+ for r in result:
63
+ r.block_until_ready()
64
+ else:
65
+ result.block_until_ready()
66
+ end = time.perf_counter()
67
+ logger.info("Compilation finished in %.2f [secs].", end - start)
68
+
69
+ def capture_model(self) -> None:
70
+ if os.getenv("SKIP_JAX_PRECOMPILE", False):
71
+ return
72
+ logger.info("Precompile all the subgraphs with possible input shapes.")
73
+
74
+ with self.runner.maybe_setup_dummy_loras(self.runner.lora_config):
75
+ self._precompile_backbone_text_only()
76
+ if self.runner.is_multimodal_model:
77
+ self._precompile_backbone_with_inputs_embeds()
78
+ self._precompile_select_from_array()
79
+ self._precompile_compute_logits()
80
+ self._precompile_disagg_utils()
81
+ self._precompile_sampling()
82
+ self._precompile_gather_logprobs()
83
+ self._precompile_structured_decoding()
84
+ if self.runner.speculative_config:
85
+ self._precompile_speculative_decoding()
86
+
87
+ def _precompile_backbone_helper(self, name, *, input_ids, positions,
88
+ inputs_embeds) -> None:
89
+ num_tokens = None
90
+ if input_ids is not None:
91
+ num_tokens = input_ids.shape[0]
92
+ elif inputs_embeds is not None:
93
+ num_tokens = inputs_embeds.shape[0]
94
+ assert num_tokens is not None
95
+
96
+ # Keep existing pattern for complex array operations
97
+ block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
98
+ block_tables = block_tables.reshape(-1)
99
+ block_tables = device_array(self.runner.mesh, block_tables)
100
+
101
+ seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
102
+ jnp.int32)
103
+ query_start_loc = self._create_dummy_tensor(
104
+ (self.runner.max_num_reqs + 1, ), jnp.int32)
105
+
106
+ # Keep existing pattern for specific value arrays
107
+ request_distribution = np.array([0, 0, 0], dtype=np.int32)
108
+ request_distribution = device_array(self.runner.mesh,
109
+ request_distribution)
110
+
111
+ attention_metadata = AttentionMetadata(
112
+ input_positions=positions,
113
+ block_tables=block_tables,
114
+ seq_lens=seq_lens,
115
+ query_start_loc=query_start_loc,
116
+ request_distribution=request_distribution,
117
+ )
118
+
119
+ def model_fn_wrapper(
120
+ state,
121
+ kv_caches,
122
+ input_ids,
123
+ attention_metadata,
124
+ inputs_embeds,
125
+ layer_name_to_kvcache_index,
126
+ lora_metadata,
127
+ ):
128
+ kv_caches, hidden_states, aux_hidden_states = self.runner.model_fn(
129
+ state, kv_caches, input_ids, attention_metadata, inputs_embeds,
130
+ layer_name_to_kvcache_index, lora_metadata)
131
+ self.runner.kv_caches = kv_caches
132
+ return hidden_states
133
+
134
+ with self.runner.maybe_select_dummy_loras(
135
+ self.runner.lora_config, np.array([num_tokens],
136
+ dtype=np.int32)):
137
+ lora_metadata = self.runner.lora_utils.extract_lora_metadata()
138
+ self._run_compilation(
139
+ name,
140
+ model_fn_wrapper,
141
+ self.runner.state,
142
+ self.runner.kv_caches,
143
+ input_ids,
144
+ attention_metadata,
145
+ inputs_embeds,
146
+ tuple(self.runner.layer_name_to_kvcache_index.items()),
147
+ lora_metadata,
148
+ num_tokens=num_tokens,
149
+ )
150
+
151
+ def _precompile_backbone_text_only(self) -> None:
152
+ for num_tokens in self.runner.num_tokens_paddings:
153
+ input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
154
+ positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
155
+ self._precompile_backbone_helper("backbone",
156
+ input_ids=input_ids,
157
+ positions=positions,
158
+ inputs_embeds=None)
159
+
160
+ def _precompile_backbone_with_inputs_embeds(self) -> None:
161
+ hidden_size = self.runner.model_config.get_hidden_size()
162
+ dtype = self.runner.model_config.dtype
163
+ for num_tokens in self.runner.num_tokens_paddings:
164
+ inputs_embeds = self._create_dummy_tensor(
165
+ (num_tokens, hidden_size), dtype)
166
+ if self.runner.uses_mrope:
167
+ positions = self._create_dummy_tensor((3, num_tokens),
168
+ jnp.int32)
169
+ else:
170
+ positions = self._create_dummy_tensor((num_tokens, ),
171
+ jnp.int32)
172
+ self._precompile_backbone_helper("backbone with embeds",
173
+ input_ids=None,
174
+ positions=positions,
175
+ inputs_embeds=inputs_embeds)
176
+
177
+ def _precompile_select_from_array_helper(
178
+ self,
179
+ name: str,
180
+ source_paddings: List[int],
181
+ indices_paddings: List[int],
182
+ hidden_dim: int,
183
+ sharding: Optional[NamedSharding] = None,
184
+ only_equal_paddings: bool = False,
185
+ check_should_skip_padding: bool = True,
186
+ ) -> None:
187
+ """Precompile select_from_array operations with various input shape combinations.
188
+
189
+ This helper method generates and precompiles the select_from_array function for different
190
+ combinations of array sizes and index counts. The operation being precompiled is
191
+ array[indices] where:
192
+ - array has shape (array_size, hidden_dim)
193
+ - indices has shape (indices_count,)
194
+ - result has shape (indices_count, hidden_dim)
195
+
196
+ This is essential for TPU compilation as JAX needs to precompile functions with all
197
+ possible input shapes that will be encountered during runtime.
198
+
199
+ Args:
200
+ name: Descriptive name for logging purposes (e.g., "select all logits")
201
+ source_paddings: List of possible sizes for the array being indexed (first dimension)
202
+ indices_paddings: List of possible counts of indices to select
203
+ hidden_dim: Second dimension size of the array (e.g., hidden_size or vocab_size)
204
+ sharding: Optional sharding specification for distributed computation
205
+ only_equal_paddings: If True, only compile when array size equals indices count
206
+ check_should_skip_padding: If True, check whether to skip certain padding combinations to reduce compilation time
207
+ """
208
+ logger.info(f"Compiling select_from_array for {name}.")
209
+ for array_size in source_paddings:
210
+ for indices_count in indices_paddings:
211
+ if check_should_skip_padding and self._should_skip_padding_combination(
212
+ array_size, indices_count, only_equal_paddings):
213
+ continue
214
+
215
+ input_tensor = self._create_dummy_tensor(
216
+ (array_size, hidden_dim), jnp.bfloat16, sharding)
217
+ indices_to_select = self._create_dummy_tensor(
218
+ (indices_count, ), jnp.int32)
219
+
220
+ self._run_compilation(
221
+ f"select_from_array [{name}]",
222
+ self.runner._select_from_array_fn, input_tensor,
223
+ indices_to_select, **{
224
+ "array_size": array_size,
225
+ "index_size": indices_count
226
+ })
227
+
228
+ def _precompile_select_from_array(self) -> None:
229
+ logger.info("Compiling select_from_array with different input shapes.")
230
+ hsize = self.runner.model_config.get_hidden_size()
231
+
232
+ if self.runner.speculative_config:
233
+ index_paddings = self.runner.num_logits_paddings
234
+ else:
235
+ index_paddings = self.runner.num_reqs_paddings
236
+
237
+ self._precompile_select_from_array_helper(
238
+ name="select all logits",
239
+ source_paddings=self.runner.num_tokens_paddings,
240
+ indices_paddings=index_paddings,
241
+ hidden_dim=hsize,
242
+ sharding=NamedSharding(self.runner.mesh, PartitionSpec(None,
243
+ None)),
244
+ )
245
+
246
+ if self.runner.speculative_config:
247
+ vocab_size = self.runner.model_config.get_vocab_size()
248
+ self._precompile_select_from_array_helper(
249
+ name="select bonus tokens for spec decoding",
250
+ source_paddings=self.runner.num_logits_paddings,
251
+ indices_paddings=self.runner.num_reqs_paddings,
252
+ hidden_dim=vocab_size,
253
+ sharding=NamedSharding(self.runner.mesh,
254
+ PartitionSpec(None, "model")),
255
+ )
256
+ self._precompile_select_from_array_helper(
257
+ name="select target tokens for spec decoding",
258
+ source_paddings=self.runner.num_logits_paddings,
259
+ indices_paddings=self.runner.num_logits_paddings,
260
+ hidden_dim=vocab_size,
261
+ sharding=NamedSharding(self.runner.mesh,
262
+ PartitionSpec(None, "model")),
263
+ only_equal_paddings=True,
264
+ )
265
+
266
+ self._precompile_select_from_array_helper(
267
+ name="select hidden states for eagle-3",
268
+ source_paddings=self.runner.num_tokens_paddings,
269
+ indices_paddings=[self.runner.max_num_reqs],
270
+ hidden_dim=hsize,
271
+ sharding=NamedSharding(self.runner.mesh,
272
+ PartitionSpec(None, None)),
273
+ check_should_skip_padding=False,
274
+ )
275
+
276
+ def _precompile_compute_logits(self) -> None:
277
+ logger.info("Compiling compute_logits with different input shapes.")
278
+ hsize = self.runner.model_config.get_hidden_size()
279
+ leading_shape = self.runner.num_reqs_paddings if not self.runner.speculative_config else self.runner.num_logits_paddings
280
+ for num_reqs in leading_shape:
281
+ hidden_states = self._create_dummy_tensor((num_reqs, hsize),
282
+ jnp.bfloat16)
283
+ with self.runner.maybe_select_dummy_loras(
284
+ self.runner.lora_config,
285
+ np.array([num_reqs], dtype=np.int32)):
286
+ lora_metadata = self.runner.lora_utils.extract_lora_metadata()
287
+ self._run_compilation(
288
+ "compute_logits",
289
+ self.runner.compute_logits_fn,
290
+ self.runner.state,
291
+ hidden_states,
292
+ lora_metadata,
293
+ num_reqs=num_reqs,
294
+ )
295
+
296
+ def _precompile_sampling(self) -> None:
297
+ logger.info("Compiling sampling with different input shapes.")
298
+ hsize = self.runner.model_config.get_vocab_size()
299
+ for num_reqs in self.runner.num_reqs_paddings:
300
+ sharding = NamedSharding(self.runner.mesh,
301
+ PartitionSpec(None, "model"))
302
+ logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
303
+ sharding)
304
+ for do_sampling in (True, False):
305
+ if do_sampling:
306
+ temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
307
+ top_k = np.full((num_reqs, ), 20, dtype=np.int32)
308
+ top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
309
+ (temperature, top_k,
310
+ top_p) = device_array(self.runner.mesh,
311
+ (temperature, top_k, top_p))
312
+ else:
313
+ temperature = None
314
+ top_k = None
315
+ top_p = None
316
+
317
+ sampling_metadata = TPUSupportedSamplingMetadata(
318
+ temperature=temperature,
319
+ top_k=top_k,
320
+ top_p=top_p,
321
+ do_sampling=do_sampling,
322
+ )
323
+ self._run_compilation(
324
+ "sample",
325
+ sample,
326
+ self.runner.rng_params_for_sampling,
327
+ self.runner.mesh,
328
+ logits,
329
+ sampling_metadata,
330
+ num_reqs=num_reqs,
331
+ do_sampling=do_sampling,
332
+ )
333
+
334
+ def _precompile_disagg_utils(self) -> None:
335
+ if not is_disagg_enabled():
336
+ return
337
+ logger.info(
338
+ "Compiling disaggregated util with different input shapes.")
339
+ block_size = self.runner.block_size
340
+ for num_blocks in range(1, self.runner.max_num_blocks_per_req // 2):
341
+ logger.info(
342
+ f"Precompile slice and insert for num_blocks {num_blocks}")
343
+ block_numbers = list(range(1, num_blocks + 1))
344
+ kv_cache_slices = self.runner.kv_cache_manager.get_kv_cache_for_block_ids(
345
+ block_numbers)
346
+ # Prevent the slices from getting freed by insert before finishing this operation
347
+ for layer_cache in kv_cache_slices:
348
+ layer_cache.block_until_ready()
349
+ self.runner.kv_caches = self.runner.kv_cache_manager._jitted_insert_continuous_kv_cache(
350
+ block_size,
351
+ self.runner.kv_caches,
352
+ kv_cache_slices,
353
+ block_numbers[0],
354
+ )
355
+ for layer_cache in self.runner.kv_caches:
356
+ layer_cache.block_until_ready()
357
+
358
+ def _precompile_gather_logprobs(self) -> None:
359
+ logger.info("Compiling gather_logprobs with different input shapes.")
360
+ hsize = self.runner.model_config.get_vocab_size()
361
+ for num_reqs in self.runner.num_reqs_paddings:
362
+ logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
363
+ token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
364
+ self._run_compilation(
365
+ "gather_logprobs",
366
+ self.runner._compute_and_gather_logprobs,
367
+ logits,
368
+ token_ids,
369
+ self.runner.model_config.max_logprobs,
370
+ num_reqs=num_reqs,
371
+ )
372
+
373
+ def _precompile_speculative_decoding(self) -> None:
374
+ logger.info(
375
+ "Compiling speculative_decoding with different input shapes.")
376
+ self._precompile_rejection_sampler()
377
+ if self.runner.speculative_config.method == "eagle3":
378
+ self._precompile_eagle3_helpers()
379
+
380
+ def _precompile_rejection_sampler(self) -> None:
381
+ logger.info("Compiling rejection_sampler with different input shapes.")
382
+ vocab_size = self.runner.model_config.get_vocab_size()
383
+ for num_logits in self.runner.num_logits_paddings:
384
+ for num_reqs in self.runner.num_reqs_paddings:
385
+ sharding = NamedSharding(self.runner.mesh,
386
+ PartitionSpec(None, "model"))
387
+ target_probs = self._create_dummy_tensor(
388
+ (num_logits, vocab_size), jnp.bfloat16, sharding)
389
+ draft_token_ids = self._create_dummy_tensor((num_logits, ),
390
+ jnp.int32)
391
+ num_draft_tokens = self._create_dummy_tensor((num_reqs, ),
392
+ jnp.int32)
393
+ bonus_token_ids = self._create_dummy_tensor((num_reqs, ),
394
+ jnp.int32)
395
+
396
+ for do_sampling in (False, True):
397
+ draft_probs = None
398
+ if do_sampling:
399
+ compilation_name = "random_rejection_sampler"
400
+ temperature = self._create_dummy_tensor((num_reqs, ),
401
+ np.float32)
402
+ top_k = self._create_dummy_tensor((num_reqs, ),
403
+ np.int32)
404
+ top_p = self._create_dummy_tensor((num_reqs, ),
405
+ np.float32)
406
+ sampling_metadata = TPUSupportedSamplingMetadata(
407
+ temperature=temperature,
408
+ top_k=top_k,
409
+ top_p=top_p,
410
+ do_sampling=do_sampling)
411
+ else:
412
+ compilation_name = "greedy_rejection_sampler"
413
+ sampling_metadata = TPUSupportedSamplingMetadata(
414
+ do_sampling=do_sampling)
415
+
416
+ self._run_compilation(
417
+ compilation_name,
418
+ self.runner.rejection_sampler,
419
+ draft_token_ids,
420
+ num_draft_tokens,
421
+ draft_probs,
422
+ target_probs,
423
+ bonus_token_ids,
424
+ sampling_metadata,
425
+ self.runner.rng_params_for_sampling,
426
+ num_logits=num_logits,
427
+ num_reqs=num_reqs,
428
+ do_sampling=do_sampling,
429
+ )
430
+
431
+ def _precompile_eagle3_helpers(self) -> None:
432
+ logger.info(
433
+ "Compiling eagle3 jitted helpers with different input shapes.")
434
+ hidden_size = self.runner.model_config.get_hidden_size()
435
+ draft_hidden_size = self.runner.vllm_config.speculative_config.draft_model_config.hf_config.hidden_size * 3
436
+ dtype = self.runner.model_config.dtype
437
+
438
+ num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
439
+ draft_kv_cache_group_id = num_kv_cache_groups - 1
440
+ block_tables = jnp.ones(
441
+ (self.runner.max_num_reqs,
442
+ cdiv(self.runner.max_model_len, self.runner.block_size)),
443
+ jnp.int32)
444
+ self._run_compilation(
445
+ "eagle3_reshape_block",
446
+ self.runner.drafter._reshape_block_tables,
447
+ block_tables,
448
+ )
449
+ block_tables = self.runner.input_batch.block_table[
450
+ draft_kv_cache_group_id].get_device_tensor().reshape(-1)
451
+ block_tables_loop = jax.device_put(
452
+ block_tables, NamedSharding(self.runner.mesh,
453
+ PartitionSpec(None, )))
454
+
455
+ selected_positions = self._create_dummy_tensor(
456
+ (self.runner.max_num_reqs, ), jnp.int32)
457
+ seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
458
+ jnp.int32)
459
+ query_start_loc = self._create_dummy_tensor(
460
+ (self.runner.max_num_reqs + 1, ), jnp.int32)
461
+ self._run_compilation("_prepare_input_loop for the first loop",
462
+ self.runner.drafter._prepare_input_loop,
463
+ selected_positions, seq_lens, block_tables)
464
+ self._run_compilation("_prepare_input_loop for the subsequent loops",
465
+ self.runner.drafter._prepare_input_loop,
466
+ selected_positions, seq_lens, block_tables_loop)
467
+
468
+ request_distribution = np.array([0, 0, 0], dtype=np.int32)
469
+ request_distribution = device_array(self.runner.mesh,
470
+ request_distribution)
471
+
472
+ for num_reqs_padding in self.runner.num_reqs_paddings:
473
+ logits = self._create_dummy_tensor(
474
+ (num_reqs_padding, self.runner.vocab_size), jnp.bfloat16,
475
+ NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
476
+ self._run_compilation(
477
+ "_get_draft_token_ids",
478
+ self.runner.drafter._get_draft_token_ids,
479
+ logits,
480
+ num_reqs=num_reqs_padding,
481
+ )
482
+ self._run_compilation(
483
+ "convert_list_to_device_array",
484
+ self.runner.speculative_decoding_manager.
485
+ _convert_list_to_device_array,
486
+ [0] * num_reqs_padding,
487
+ num_reqs=num_reqs_padding,
488
+ )
489
+ for i in range(1, self.runner.drafter.num_speculative_tokens + 1):
490
+ draft_token_ids_list = [
491
+ self._create_dummy_tensor(
492
+ (num_reqs_padding, ), jnp.int32,
493
+ NamedSharding(self.runner.mesh, PartitionSpec()))
494
+ for _ in range(i)
495
+ ]
496
+ self._run_compilation(
497
+ "_stack_draft_token_ids",
498
+ self.runner.drafter._stack_draft_token_ids,
499
+ draft_token_ids_list,
500
+ num_reqs=num_reqs_padding,
501
+ draft_token_ids_list_length=len(draft_token_ids_list))
502
+
503
+ for num_logits in self.runner.num_logits_paddings:
504
+ hidden_states = self._create_dummy_tensor(
505
+ (num_logits, hidden_size), jnp.bfloat16)
506
+ self._run_compilation(
507
+ "drafter_compute_logits",
508
+ self.runner.drafter.compute_logits_fn,
509
+ self.runner.drafter.state,
510
+ hidden_states,
511
+ None,
512
+ num_logits=num_logits,
513
+ )
514
+
515
+ position_indices = self._create_dummy_tensor(
516
+ (self.runner.max_num_reqs, ), jnp.int32)
517
+ next_token_ids = self._create_dummy_tensor(
518
+ (self.runner.max_num_reqs, ), jnp.int32)
519
+ input_ids_loop = self._create_dummy_tensor(
520
+ (self.runner.max_num_reqs, ), jnp.int32,
521
+ NamedSharding(self.runner.mesh, PartitionSpec()))
522
+ target_hidden_state_loop = self._create_dummy_tensor(
523
+ (self.runner.max_num_reqs, hidden_size), dtype,
524
+ NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
525
+ for num_tokens in self.runner.num_tokens_paddings:
526
+ positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
527
+ self._run_compilation(
528
+ "select_from_array [select input positions for eagle3]",
529
+ self.runner._select_from_array_fn,
530
+ positions,
531
+ position_indices,
532
+ num_tokens=num_tokens)
533
+
534
+ aux_hidden_states = [
535
+ self._create_dummy_tensor((num_tokens, hidden_size), dtype),
536
+ self._create_dummy_tensor((num_tokens, hidden_size), dtype),
537
+ self._create_dummy_tensor((num_tokens, hidden_size), dtype),
538
+ ]
539
+ self._run_compilation(
540
+ "eagle3_concate_hidden_states",
541
+ self.runner.drafter._concate_hidden_states,
542
+ aux_hidden_states,
543
+ num_tokens=num_tokens,
544
+ )
545
+
546
+ input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
547
+ aux_hidden_states = [
548
+ self._create_dummy_tensor(
549
+ (num_tokens, hidden_size), jnp.bfloat16,
550
+ NamedSharding(self.runner.mesh, PartitionSpec(None,
551
+ None))),
552
+ self._create_dummy_tensor(
553
+ (num_tokens, hidden_size), jnp.bfloat16,
554
+ NamedSharding(self.runner.mesh, PartitionSpec(None,
555
+ None))),
556
+ self._create_dummy_tensor(
557
+ (num_tokens, hidden_size), jnp.bfloat16,
558
+ NamedSharding(self.runner.mesh, PartitionSpec(None,
559
+ None))),
560
+ ]
561
+ for num_indices in self.runner.num_tokens_paddings:
562
+ indices = jnp.ones((num_indices, ), dtype=jnp.int32)
563
+ self._run_compilation(
564
+ "select_from_array [select input ids for eagle3]",
565
+ self.runner._select_from_array_fn,
566
+ input_ids,
567
+ indices,
568
+ num_tokens=num_tokens,
569
+ num_indices=num_indices)
570
+ self._run_compilation(
571
+ "select_from_array [select hidden states for eagle3]",
572
+ self.runner.drafter._select_target_hidden_states,
573
+ aux_hidden_states, indices)
574
+
575
+ attention_metadata = AttentionMetadata(
576
+ input_positions=positions,
577
+ block_tables=block_tables,
578
+ seq_lens=seq_lens,
579
+ query_start_loc=query_start_loc,
580
+ request_distribution=request_distribution,
581
+ )
582
+ target_hidden_states = self._create_dummy_tensor(
583
+ (num_tokens, hidden_size), dtype,
584
+ NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
585
+
586
+ def draft_model_fn_wrapper(
587
+ state,
588
+ kv_caches,
589
+ input_ids,
590
+ target_hidden_states,
591
+ attention_metadata,
592
+ ):
593
+ kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
594
+ state, kv_caches, input_ids, target_hidden_states,
595
+ attention_metadata)
596
+ self.runner.kv_caches = kv_caches
597
+ return hidden_states
598
+
599
+ self._run_compilation(
600
+ "draft_model_fn",
601
+ draft_model_fn_wrapper,
602
+ self.runner.drafter.state,
603
+ self.runner.kv_caches,
604
+ input_ids,
605
+ target_hidden_states,
606
+ attention_metadata,
607
+ num_tokens=num_tokens,
608
+ )
609
+
610
+ attention_metadata.query_start_loc = jax.device_put(
611
+ attention_metadata.query_start_loc,
612
+ NamedSharding(self.runner.mesh, PartitionSpec()))
613
+ attention_metadata.block_tables = block_tables_loop
614
+ attention_metadata.input_positions = self._create_dummy_tensor(
615
+ (self.runner.max_num_reqs, ), jnp.int32)
616
+ self._run_compilation(
617
+ "draft_model_fn in a loop",
618
+ draft_model_fn_wrapper,
619
+ self.runner.drafter.state,
620
+ self.runner.kv_caches,
621
+ input_ids_loop,
622
+ target_hidden_state_loop,
623
+ attention_metadata,
624
+ num_tokens=num_tokens,
625
+ )
626
+
627
+ target_hidden_states = self._create_dummy_tensor(
628
+ (num_tokens, draft_hidden_size), dtype)
629
+ self._run_compilation(
630
+ "draft_model_combine_hidden_states_fn",
631
+ self.runner.drafter.combine_hidden_states_fn,
632
+ self.runner.drafter.state,
633
+ target_hidden_states,
634
+ num_tokens=num_tokens,
635
+ )
636
+
637
+ target_token_ids = self._create_dummy_tensor((num_tokens, ),
638
+ jnp.int32)
639
+ self._run_compilation(
640
+ "_prepare_input_ids",
641
+ self.runner.drafter._prepare_input_ids,
642
+ query_start_loc,
643
+ target_token_ids,
644
+ next_token_ids,
645
+ self.runner.input_batch.num_reqs,
646
+ )
647
+
648
+ def _precompile_structured_decoding(self) -> None:
649
+ logger.info(
650
+ "Compiling structured_decoding with different input shapes.")
651
+ for num_reqs in self.runner.num_reqs_paddings:
652
+ dummy_logits = self._create_dummy_tensor(
653
+ (num_reqs, self.runner.vocab_size), jnp.bfloat16)
654
+ dummy_require_struct_decoding = self.runner.require_structured_out_cpu[:
655
+ num_reqs]
656
+ dummy_grammar_bitmask = self.runner.grammar_bitmask_cpu[:num_reqs]
657
+
658
+ (dummy_logits, dummy_require_struct_decoding,
659
+ dummy_grammar_bitmask, arange) = device_array(
660
+ self.runner.mesh,
661
+ (dummy_logits, dummy_require_struct_decoding,
662
+ dummy_grammar_bitmask, self.runner.structured_decode_arange))
663
+
664
+ self._run_compilation(
665
+ "structured_decode",
666
+ self.runner.structured_decoding_manager.structured_decode_fn,
667
+ dummy_require_struct_decoding,
668
+ dummy_grammar_bitmask,
669
+ dummy_logits,
670
+ arange,
671
+ num_reqs=num_reqs,
672
+ )