tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/platforms/tpu_platform.py

@@ -1,22 +1,22 @@
  # SPDX-License-Identifier: Apache-2.0

- from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+ import os
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

  import jax.numpy as jnp
- import torch
  import vllm.envs as vllm_envs
+ from torchax.ops.mappings import j2t_dtype
  from tpu_info import device
  from vllm.inputs import ProcessorInputs, PromptType
  from vllm.platforms.interface import Platform, PlatformEnum
  from vllm.sampling_params import SamplingParams, SamplingType

  from tpu_inference import envs
- from tpu_inference.layers.common.sharding import ShardingConfigManager
+ from tpu_inference.layers.jax.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger
- from tpu_inference.utils import to_jax_dtype, to_torch_dtype

  if TYPE_CHECKING:
- from vllm.attention.backends.registry import AttentionBackendEnum
+ from vllm.attention.backends.registry import _Backend
  from vllm.config import BlockSize, ModelConfig, VllmConfig
  from vllm.pooling_params import PoolingParams
  else:
@@ -24,10 +24,16 @@ else:
  ModelConfig = None
  VllmConfig = None
  PoolingParams = None
- AttentionBackendEnum = None
+ _Backend = None

  logger = init_logger(__name__)

+ _DTYPE: dict[str, jnp.dtype] = {
+ "bfloat16": jnp.bfloat16,
+ "float": jnp.float32,
+ "float32": jnp.float32,
+ }
+

  class TpuPlatform(Platform):
  _enum = PlatformEnum.TPU
@@ -48,13 +54,12 @@ class TpuPlatform(Platform):
  ]

  @classmethod
- def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
- head_size: int, dtype: jnp.dtype,
- kv_cache_dtype: Optional[str], block_size: int,
- use_v1: bool, use_mla: bool, has_sink: bool,
- use_sparse: bool, attn_type: Any) -> str:
- from vllm.attention.backends.registry import AttentionBackendEnum
- if selected_backend != AttentionBackendEnum.PALLAS:
+ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
+ dtype: jnp.dtype, kv_cache_dtype: Optional[str],
+ block_size: int, use_v1: bool, use_mla: bool,
+ has_sink: bool, use_sparse: bool) -> str:
+ from vllm.attention.backends.registry import _Backend
+ if selected_backend != _Backend.PALLAS:
  logger.info("Cannot use %s backend on TPU.", selected_backend)

  if use_v1:
@@ -77,14 +82,6 @@ class TpuPlatform(Platform):
  logger.warning(f"Error getting device name: {e}")
  return 'TPU'

- @classmethod
- def fp8_dtype(cls) -> torch.dtype:
- if cls.get_device_name().lower() == "tpu v6e":
- logger.info(
- "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
- return torch.float8_e5m2
- return torch.float8_e4m3fn
-
  @classmethod
  def get_device_total_memory(cls, device_id: int = 0) -> int:
  raise NotImplementedError
@@ -135,7 +132,6 @@ class TpuPlatform(Platform):
  # For v0, the default block size is 16.
  if cache_config and cache_config.block_size is None:
  cache_config.block_size = cast(BlockSize, 16)
-
  compilation_config = vllm_config.compilation_config

  # TPU only supports DYNAMO_TRACE_ONCE compilation level
@@ -152,19 +148,20 @@ class TpuPlatform(Platform):
  # NOTE(xiang): convert dtype to jnp.dtype
  # NOTE(wenlong): skip this logic for mm model preprocessing
  # For mm model preprocessors, it may need the output dtype to be torch.
- # In order to avoid a PR to vLLM, we postpone the dtype checking during
- # tpu_worker initialization
+ # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
  if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
- model_dtype = vllm_config.model_config.dtype
- try:
- dtype = to_jax_dtype(model_dtype)
- except ValueError:
- logger.warning(f"{model_dtype=} is not supported. "
- "Falling back to jnp.bfloat16")
- dtype = jnp.bfloat16
- if impl == "vllm":
- dtype = to_torch_dtype(dtype)
- vllm_config.model_config.dtype = dtype
+ if not isinstance(vllm_config.model_config.dtype, str):
+ logger.warning(
+ "The model dtype is not properly set for JAX backend. "
+ "Overwriting it to jnp.bfloat16")
+ vllm_config.model_config.dtype = jnp.bfloat16
+ else:
+ vllm_config.model_config.dtype = _DTYPE.get(
+ vllm_config.model_config.dtype, jnp.bfloat16)
+
+ if impl == "vllm":
+ vllm_config.model_config.dtype = j2t_dtype(
+ vllm_config.model_config.dtype.dtype)

  # TODO(cuiq): remove this dependency.
  from vllm.v1.attention.backends.pallas import PallasAttentionBackend
@@ -185,16 +182,10 @@ class TpuPlatform(Platform):
  parallel_config.worker_cls = \
  "tpu_inference.worker.tpu_worker.TPUWorker"

- multihost_backend = envs.TPU_MULTIHOST_BACKEND
+ multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
  if not multihost_backend: # Single host
- if parallel_config.pipeline_parallel_size == 1:
- logger.info("Force using UniProcExecutor for JAX on \
- single host without pipeline parallelism.")
- parallel_config.distributed_executor_backend = "uni"
- else:
- logger.info("Force using MultiprocExecutor for JAX on \
- single host with pipeline parallelism.")
- parallel_config.distributed_executor_backend = "mp"
+ logger.info("Force using UniProcExecutor for JAX on single host.")
+ parallel_config.distributed_executor_backend = "uni"
  elif multihost_backend == "ray":
  from tpu_inference.executors.ray_distributed_executor import \
  RayDistributedExecutor
@@ -269,7 +260,3 @@ class TpuPlatform(Platform):
  Returns if the current platform needs to sync weight loader.
  """
  return True
-
- @classmethod
- def support_hybrid_kv_cache(cls) -> bool:
- return True
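
The dtype hunks above replace the old to_jax_dtype/to_torch_dtype helpers with the module-level _DTYPE table plus torchax's j2t_dtype. The following standalone sketch condenses that logic for readability; the function name and bare-argument form are ours rather than the package's, while the mapping and the bfloat16 fallback are taken from the hunk:

```python
import jax.numpy as jnp
from torchax.ops.mappings import j2t_dtype  # imported by the new tpu_platform.py as well

_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
}


def resolve_model_dtype(configured_dtype, impl):
    """Illustrative only: mirrors the new dtype normalization outside VllmConfig."""
    if not isinstance(configured_dtype, str):
        # Non-string dtypes are overwritten, as the warning branch in the hunk does.
        dtype = jnp.bfloat16
    else:
        # Unknown strings fall back to bfloat16 rather than raising.
        dtype = _DTYPE.get(configured_dtype, jnp.bfloat16)
    if impl == "vllm":
        # The torch-backed implementation still expects a torch dtype;
        # .dtype turns the jnp scalar type into a numpy dtype for torchax.
        return j2t_dtype(dtype.dtype)
    return dtype


# resolve_model_dtype("float32", impl="native")  # any non-"vllm" impl -> jnp.float32
# resolve_model_dtype("bfloat16", impl="vllm")   # expected -> torch.bfloat16
```

One nuance visible in the hunk: the old path logged a warning before falling back to bfloat16 on an unsupported dtype, whereas the new string path falls back silently via dict.get; the warning now only fires when the configured dtype is not a string.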
tpu_inference/runner/compilation_manager.py

@@ -1,22 +1,20 @@
+ import os
  import time
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

  import jax
  import jax.numpy as jnp
  import numpy as np
- import vllm.envs as vllm_envs
+ import vllm.envs as envs
  from jax.sharding import NamedSharding, PartitionSpec

- import tpu_inference.envs as envs
  from tpu_inference.core.disagg_utils import is_disagg_enabled
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.sample.sampling import sample
  from tpu_inference.layers.jax.sample.sampling_metadata import \
  TPUSupportedSamplingMetadata
+ from tpu_inference.layers.jax.sharding import ShardingAxisName
  from tpu_inference.logger import init_logger
- from tpu_inference.models.jax.jax_intermediate_tensor import \
- JaxIntermediateTensors
  from tpu_inference.utils import device_array

  if TYPE_CHECKING:
@@ -32,10 +30,10 @@ class CompilationManager:

  def __init__(self, runner: "TPUModelRunner"):
  self.runner = runner
- if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
+ if not envs.VLLM_DISABLE_COMPILE_CACHE:
  logger.info("Enabling JAX compile cache.")
  jax.config.update("jax_compilation_cache_dir",
- vllm_envs.VLLM_XLA_CACHE_PATH)
+ envs.VLLM_XLA_CACHE_PATH)

  def _create_dummy_tensor(self,
  shape: Tuple[int, ...],
@@ -69,7 +67,8 @@ class CompilationManager:
  logger.info("Compilation finished in %.2f [secs].", end - start)

  def capture_model(self) -> None:
- if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
+ if os.getenv("SKIP_JAX_PRECOMPILE",
+ False) or self.runner.model_config.enforce_eager:
  return
  logger.info("Precompile all the subgraphs with possible input shapes.")

@@ -82,8 +81,6 @@ class CompilationManager:
  self._precompile_backbone_with_inputs_embeds()
  if self.runner.scheduler_config.async_scheduling:
  self._precompile_substitute_placeholder_token()
- if not self.runner.is_last_rank:
- return
  self._precompile_select_from_array()
  self._precompile_compute_logits()
  self._precompile_disagg_utils()
@@ -123,15 +120,8 @@ class CompilationManager:
  num_tokens=num_tokens,
  )

- def _precompile_backbone_helper(self,
- name,
- *,
- input_ids,
- positions,
- inputs_embeds,
- intermediate_tensors=None,
- is_first_rank=True,
- is_last_rank=True) -> None:
+ def _precompile_backbone_helper(self, name, *, input_ids, positions,
+ inputs_embeds) -> None:
  num_tokens = None
  if input_ids is not None:
  num_tokens = input_ids.shape[0]
@@ -145,6 +135,12 @@ class CompilationManager:
  ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

  # Keep existing pattern for complex array operations
+ block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
+ block_tables = block_tables.reshape(-1)
+ block_tables = device_array(self.runner.mesh,
+ block_tables,
+ sharding=dp_sharding)
+
  seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
  jnp.int32, dp_sharding)
  query_start_loc = self._create_dummy_tensor(
@@ -156,49 +152,26 @@ class CompilationManager:
  request_distribution,
  sharding=dp_sharding)

- attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
- uniform_attention_metadata: AttentionMetadata = None
- for kv_cache_gid, kv_cache_group in enumerate(
- self.runner.kv_cache_config.kv_cache_groups):
- block_tables = self.runner.block_tables_cpu[
- kv_cache_gid][:self.runner.max_num_reqs]
- block_tables = block_tables.reshape(-1)
- block_tables = device_array(self.runner.mesh,
- block_tables,
- sharding=dp_sharding)
-
- attention_metadata_gid = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution,
- )
- if not self.runner.use_hybrid_kvcache:
- # all layers share the same attention metadata
- uniform_attention_metadata = attention_metadata_gid
- else:
- for layer_name in kv_cache_group.layer_names:
- attention_metadata_per_layer[
- layer_name] = attention_metadata_gid
+ attention_metadata = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution,
+ )

  def model_fn_wrapper(
  state,
  kv_caches,
  input_ids,
  attention_metadata,
- positions,
  inputs_embeds,
  layer_name_to_kvcache_index,
  lora_metadata,
- intermediate_tensors,
- is_first_rank,
- is_last_rank,
  ):
  kv_caches, hidden_states, _ = self.runner.model_fn(
  state, kv_caches, input_ids, attention_metadata, inputs_embeds,
- positions, layer_name_to_kvcache_index, lora_metadata,
- intermediate_tensors, is_first_rank, is_last_rank)
+ layer_name_to_kvcache_index, lora_metadata)
  self.runner.kv_caches = kv_caches
  return hidden_states

@@ -206,10 +179,6 @@ class CompilationManager:
  self.runner.lora_config, np.array([num_tokens],
  dtype=np.int32)):
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
- if self.runner.use_hybrid_kvcache:
- attention_metadata = attention_metadata_per_layer
- else:
- attention_metadata = uniform_attention_metadata
  self._run_compilation(
  name,
  model_fn_wrapper,
@@ -217,13 +186,9 @@ class CompilationManager:
  self.runner.kv_caches,
  input_ids,
  attention_metadata,
- positions,
  inputs_embeds,
  tuple(self.runner.layer_name_to_kvcache_index.items()),
  lora_metadata,
- intermediate_tensors,
- is_first_rank,
- is_last_rank,
  num_tokens=num_tokens,
  )

@@ -274,7 +239,6 @@ class CompilationManager:
  )

  def _precompile_backbone_text_only(self) -> None:
- hidden_size = self.runner.model_config.get_hidden_size()
  for num_tokens in self.runner.num_tokens_paddings:
  dp_sharding = NamedSharding(
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -284,28 +248,10 @@ class CompilationManager:
  dp_sharding)
  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
  dp_sharding)
- is_first_rank = self.runner.is_first_rank
- is_last_rank = self.runner.is_last_rank
- if is_first_rank:
- intermediate_tensors = None
- else:
- hidden_states = self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16)
- residual = self._create_dummy_tensor((num_tokens, hidden_size),
- jnp.bfloat16)
- intermediate_tensors = JaxIntermediateTensors(
- tensors={
- "hidden_states": hidden_states,
- "residual": residual
- })
- self._precompile_backbone_helper(
- f"worker{self.runner.rank} backbone",
- input_ids=input_ids,
- positions=positions,
- inputs_embeds=None,
- intermediate_tensors=intermediate_tensors,
- is_first_rank=is_first_rank,
- is_last_rank=is_last_rank)
+ self._precompile_backbone_helper("backbone",
+ input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=None)

  def _precompile_backbone_with_inputs_embeds(self) -> None:
  hidden_size = self.runner.model_config.get_hidden_size()
@@ -319,28 +265,10 @@ class CompilationManager:
  else:
  positions = self._create_dummy_tensor((num_tokens, ),
  jnp.int32)
- is_first_rank = self.runner.is_first_rank
- is_last_rank = self.runner.is_last_rank
- if not is_first_rank:
- hidden_states = self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16)
- residual = self._create_dummy_tensor((num_tokens, hidden_size),
- jnp.bfloat16)
- intermediate_tensors = JaxIntermediateTensors(
- tensors={
- "hidden_states": hidden_states,
- "residual": residual
- })
- else:
- intermediate_tensors = None
- self._precompile_backbone_helper(
- f"worker{self.runner.rank} backbone with embeds",
- input_ids=None,
- positions=positions,
- inputs_embeds=inputs_embeds,
- intermediate_tensors=intermediate_tensors,
- is_first_rank=is_first_rank,
- is_last_rank=is_last_rank)
+ self._precompile_backbone_helper("backbone with embeds",
+ input_ids=None,
+ positions=positions,
+ inputs_embeds=inputs_embeds)

  def _precompile_select_from_array_helper(
  self,
@@ -404,23 +332,20 @@ class CompilationManager:
  index_paddings = self.runner.num_reqs_paddings
  dp_sharding = NamedSharding(self.runner.mesh,
  PartitionSpec(ShardingAxisName.ATTN_DATA))
- hidden_states_sharding = NamedSharding(
- self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  self._precompile_select_from_array_helper(
- name=f"worker{self.runner.rank} select all logits",
+ name="select all logits",
  source_paddings=self.runner.num_tokens_paddings,
  indices_paddings=index_paddings,
  hidden_dim=hsize,
- input_sharding=hidden_states_sharding,
+ input_sharding=dp_sharding,
  indices_sharding=dp_sharding if dp_size > 1 else None,
  )

  if self.runner.speculative_config:
  vocab_size = self.runner.model_config.get_vocab_size()
  self._precompile_select_from_array_helper(
- name=
- f"worker{self.runner.rank} select bonus tokens for spec decoding",
+ name="select bonus tokens for spec decoding",
  source_paddings=self.runner.num_logits_paddings,
  indices_paddings=self.runner.num_reqs_paddings,
  hidden_dim=vocab_size,
@@ -428,8 +353,7 @@ class CompilationManager:
  PartitionSpec(None, "model")),
  )
  self._precompile_select_from_array_helper(
- name=
- f"worker{self.runner.rank} select target tokens for spec decoding",
+ name="select target tokens for spec decoding",
  source_paddings=self.runner.num_logits_paddings,
  indices_paddings=self.runner.num_logits_paddings,
  hidden_dim=vocab_size,
@@ -452,7 +376,7 @@ class CompilationManager:
  np.array([num_reqs], dtype=np.int32)):
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
  self._run_compilation(
- f"worker{self.runner.rank} compute_logits",
+ "compute_logits",
  self.runner.compute_logits_fn,
  self.runner.state,
  hidden_states,
@@ -494,7 +418,7 @@ class CompilationManager:
  do_sampling=do_sampling,
  )
  self._run_compilation(
- f"worker{self.runner.rank} sample",
+ "sample",
  sample,
  self.runner.rng_params_for_sampling,
  self.runner.mesh,
@@ -535,7 +459,7 @@ class CompilationManager:
  logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
  token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
  self._run_compilation(
- f"worker{self.runner.rank} gather_logprobs",
+ "gather_logprobs",
  self.runner._compute_and_gather_logprobs,
  logits,
  token_ids,
@@ -587,7 +511,7 @@ class CompilationManager:
  do_sampling=do_sampling)

  self._run_compilation(
- f"worker{self.runner.rank} {compilation_name}",
+ compilation_name,
  self.runner.rejection_sampler,
  draft_token_ids,
  num_draft_tokens,
@@ -604,9 +528,7 @@ class CompilationManager:
  def _precompile_eagle3_helpers(self) -> None:
  logger.info(
  "Compiling eagle3 jitted helpers with different input shapes.")
- target_hidden_size = self.runner.model_config.get_hidden_size()
- draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
- )
+ hidden_size = self.runner.model_config.get_hidden_size()
  dtype = self.runner.model_config.dtype

  num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -653,11 +575,10 @@ class CompilationManager:

  for num_logits in self.runner.num_logits_paddings:
  hidden_states = self._create_dummy_tensor(
- (num_logits, draft_hidden_size), jnp.bfloat16)
+ (num_logits, hidden_size), jnp.bfloat16)
  self._run_compilation(
  "eagle3_get_draft_token_ids",
  self.runner.drafter._get_draft_token_ids,
- self.runner.drafter.state,
  hidden_states,
  num_logits=num_logits,
  )
@@ -665,8 +586,8 @@ class CompilationManager:
  input_ids_loop = self._create_dummy_tensor(
  (self.runner.max_num_reqs, ), jnp.int32,
  NamedSharding(self.runner.mesh, PartitionSpec()))
- draft_hidden_state_loop = self._create_dummy_tensor(
- (self.runner.max_num_reqs, draft_hidden_size), dtype,
+ target_hidden_state_loop = self._create_dummy_tensor(
+ (self.runner.max_num_reqs, hidden_size), dtype,
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
  next_token_ids = self._create_dummy_tensor(
  (self.runner.max_num_reqs, ), jnp.int32)
@@ -674,12 +595,9 @@ class CompilationManager:
  (self.runner.max_num_reqs, ), jnp.int32)
  for num_tokens in self.runner.num_tokens_paddings:
  aux_hidden_states = [
- self._create_dummy_tensor((num_tokens, target_hidden_size),
- dtype),
- self._create_dummy_tensor((num_tokens, target_hidden_size),
- dtype),
- self._create_dummy_tensor((num_tokens, target_hidden_size),
- dtype),
+ self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+ self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+ self._create_dummy_tensor((num_tokens, hidden_size), dtype),
  ]

  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -702,23 +620,23 @@ class CompilationManager:
  num_reqs,
  ):
  target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
- self.runner.drafter.state, token_indices, query_start_loc,
- seq_lens, input_ids, aux_hidden_states, attention_metadata,
- next_token_ids, num_reqs)
+ token_indices, query_start_loc, seq_lens, input_ids,
+ aux_hidden_states, attention_metadata, next_token_ids,
+ num_reqs)
  return target_hidden_states, input_ids, last_token_indices

  input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
  aux_hidden_states = [
  self._create_dummy_tensor(
- (num_tokens, target_hidden_size), jnp.bfloat16,
+ (num_tokens, hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  self._create_dummy_tensor(
- (num_tokens, target_hidden_size), jnp.bfloat16,
+ (num_tokens, hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  self._create_dummy_tensor(
- (num_tokens, target_hidden_size), jnp.bfloat16,
+ (num_tokens, hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  ]
@@ -750,17 +668,17 @@ class CompilationManager:
  state,
  kv_caches,
  input_ids,
- draft_hidden_states,
+ target_hidden_states,
  attention_metadata,
  ):
  kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
- state, kv_caches, input_ids, draft_hidden_states,
+ state, kv_caches, input_ids, target_hidden_states,
  attention_metadata)
  self.runner.kv_caches = kv_caches
  return hidden_states

- draft_hidden_states = self._create_dummy_tensor(
- (num_tokens, draft_hidden_size), dtype,
+ target_hidden_states = self._create_dummy_tensor(
+ (num_tokens, hidden_size), dtype,
  NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
  input_ids = self._create_dummy_tensor(
  (num_tokens, ), jnp.int32,
@@ -771,7 +689,7 @@ class CompilationManager:
  self.runner.drafter.state,
  self.runner.kv_caches,
  input_ids,
- draft_hidden_states,
+ target_hidden_states,
  attention_metadata,
  num_tokens=num_tokens,
  )
@@ -781,7 +699,6 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_prepare_hidden_states_and_input_ids",
  self.runner.drafter._prepare_hidden_states_and_input_ids,
- self.runner.drafter.state,
  aux_hidden_states,
  query_start_loc,
  target_token_ids,
@@ -804,19 +721,18 @@ class CompilationManager:
  self.runner.drafter.state,
  self.runner.kv_caches,
  input_ids_loop,
- draft_hidden_state_loop,
+ target_hidden_state_loop,
  attention_metadata,
  num_tokens=num_tokens,
  )

  hidden_states = self._create_dummy_tensor(
- (num_tokens, draft_hidden_size), jnp.bfloat16,
+ (num_tokens, hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

  self._run_compilation(
  "eagle3_select_inputs_for_loop_speculation",
  self.runner.drafter._select_inputs_for_loop_speculation,
- self.runner.drafter.state,
  positions,
  hidden_states,
  hidden_states,
@@ -827,7 +743,6 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_select_draft_token_ids",
  self.runner.drafter._select_draft_token_ids,
- self.runner.drafter.state,
  hidden_states,
  last_token_indices,
  num_tokens=num_tokens,
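
One behavioral detail in the capture_model hunk above: the guard moved from the typed envs.SKIP_JAX_PRECOMPILE accessor to a raw os.getenv call, and os.getenv returns the variable's string value as-is when it is set. A quick standalone illustration (ours, not package code):

```python
import os

# Any non-empty value, including "0", makes the new guard truthy and skips precompile.
os.environ["SKIP_JAX_PRECOMPILE"] = "0"
print(bool(os.getenv("SKIP_JAX_PRECOMPILE", False)))   # True

os.environ.pop("SKIP_JAX_PRECOMPILE")
print(bool(os.getenv("SKIP_JAX_PRECOMPILE", False)))   # False (falls back to the default)
```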
tpu_inference/runner/kv_cache.py

@@ -9,7 +9,7 @@ from torchax.ops.mappings import t2j_dtype

  import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
- from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.jax.sharding import ShardingAxisName
  from tpu_inference.logger import init_logger

  logger = init_logger(__name__)
@@ -82,7 +82,7 @@ def create_kv_caches(
  ShardingAxisName.ATTN_HEAD))

  def _allocate() -> jax.Array:
- return jnp.zeros(
+ return jnp.empty(
  shape=cache_shape,
  dtype=cache_dtype,
  )
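
A closing note on the _allocate change above: in JAX, jnp.empty is documented to return a zero-filled array because XLA has no notion of uninitialized buffers, so the switch from jnp.zeros reads as a statement of intent rather than a behavioral change. A tiny sketch with made-up shapes:

```python
import jax.numpy as jnp

cache_shape = (2, 16, 8, 64)   # hypothetical (num_blocks, block_size, num_kv_heads, head_dim)
cache_dtype = jnp.bfloat16

block = jnp.empty(shape=cache_shape, dtype=cache_dtype)
print(bool((block == 0).all()))   # True: same contents as jnp.zeros(cache_shape, cache_dtype)
```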