tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff compares the content of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.

Files changed (37)
  1. tests/test_envs.py +182 -0
  2. tests/test_utils.py +23 -14
  3. tpu_inference/core/core_tpu.py +17 -9
  4. tpu_inference/executors/ray_distributed_executor.py +24 -11
  5. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +33 -10
  6. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  7. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  8. tpu_inference/layers/common/quant_methods.py +8 -0
  9. tpu_inference/layers/jax/attention/attention.py +1 -1
  10. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  11. tpu_inference/layers/jax/sample/sampling.py +2 -2
  12. tpu_inference/layers/vllm/attention.py +1 -1
  13. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  14. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  15. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  16. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  17. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  18. tpu_inference/models/common/model_loader.py +3 -2
  19. tpu_inference/models/jax/llama3.py +2 -2
  20. tpu_inference/models/jax/phi3.py +1 -1
  21. tpu_inference/models/jax/qwen2.py +1 -1
  22. tpu_inference/models/jax/qwen2_5_vl.py +2 -2
  23. tpu_inference/models/jax/qwen3.py +1 -1
  24. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  25. tpu_inference/platforms/tpu_platform.py +12 -5
  26. tpu_inference/runner/compilation_manager.py +4 -2
  27. tpu_inference/runner/kv_cache.py +1 -1
  28. tpu_inference/runner/tpu_runner.py +31 -7
  29. tpu_inference/utils.py +2 -2
  30. tpu_inference/worker/tpu_worker.py +1 -1
  31. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +1 -1
  32. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +37 -34
  33. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  34. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  35. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  36. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  37. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tests/test_envs.py ADDED
@@ -0,0 +1,182 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+ import pytest
+
+ import tpu_inference.envs as envs
+ from tpu_inference.envs import enable_envs_cache, environment_variables
+
+
+ def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
+     assert envs.JAX_PLATFORMS == ""
+     assert envs.PHASED_PROFILING_DIR == ""
+     monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+     monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
+     assert envs.JAX_PLATFORMS == "tpu"
+     assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"
+
+     assert envs.TPU_NAME is None
+     assert envs.TPU_ACCELERATOR_TYPE is None
+     monkeypatch.setenv("TPU_NAME", "my-tpu")
+     monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
+     assert envs.TPU_NAME == "my-tpu"
+     assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"
+
+     # __getattr__ is not decorated with functools.cache
+     assert not hasattr(envs.__getattr__, "cache_info")
+
+
+ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+     monkeypatch.setenv("TPU_NAME", "my-tpu")
+
+     # __getattr__ is not decorated with functools.cache
+     assert not hasattr(envs.__getattr__, "cache_info")
+
+     enable_envs_cache()
+
+     # __getattr__ is decorated with functools.cache
+     assert hasattr(envs.__getattr__, "cache_info")
+     start_hits = envs.__getattr__.cache_info().hits
+
+     # 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
+     assert envs.JAX_PLATFORMS == "tpu"
+     assert envs.TPU_NAME == "my-tpu"
+     assert envs.__getattr__.cache_info().hits == start_hits + 2
+
+     # All environment variables are cached
+     for environment_variable in environment_variables:
+         envs.__getattr__(environment_variable)
+     assert envs.__getattr__.cache_info(
+     ).hits == start_hits + 2 + len(environment_variables)
+
+     # Reset envs.__getattr__ back to non-cached version to
+     # avoid affecting other tests
+     envs.__getattr__ = envs.__getattr__.__wrapped__
+
+
+ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+     # Test SKIP_JAX_PRECOMPILE (default False)
+     assert envs.SKIP_JAX_PRECOMPILE is False
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
+     assert envs.SKIP_JAX_PRECOMPILE is True
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+     assert envs.SKIP_JAX_PRECOMPILE is False
+
+     # Test NEW_MODEL_DESIGN (default False)
+     assert envs.NEW_MODEL_DESIGN is False
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
+     assert envs.NEW_MODEL_DESIGN is True
+
+     # Test USE_MOE_EP_KERNEL (default False)
+     assert envs.USE_MOE_EP_KERNEL is False
+     monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
+     assert envs.USE_MOE_EP_KERNEL is True
+
+
+ def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+     assert envs.PYTHON_TRACER_LEVEL == 1
+     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
+     assert envs.PYTHON_TRACER_LEVEL == 3
+     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
+     assert envs.PYTHON_TRACER_LEVEL == 0
+
+
+ def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
+     assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+     monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
+     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+ def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
+     monkeypatch.delenv("PREFILL_SLICES", raising=False)
+     monkeypatch.delenv("DECODE_SLICES", raising=False)
+
+     assert envs.JAX_PLATFORMS == ""
+     assert envs.PREFILL_SLICES == ""
+     assert envs.DECODE_SLICES == ""
+     assert envs.PHASED_PROFILING_DIR == ""
+
+
+ def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
+     monkeypatch.delenv("TPU_NAME", raising=False)
+     monkeypatch.delenv("TPU_WORKER_ID", raising=False)
+
+     assert envs.TPU_ACCELERATOR_TYPE is None
+     assert envs.TPU_NAME is None
+     assert envs.TPU_WORKER_ID is None
+
+
+ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
+     assert envs.RAY_USAGE_STATS_ENABLED == "0"
+     monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
+     assert envs.RAY_USAGE_STATS_ENABLED == "1"
+
+     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
+     monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
+     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
+
+
+ def test_invalid_attribute_raises_error():
+     with pytest.raises(AttributeError,
+                        match="has no attribute 'NONEXISTENT_VAR'"):
+         _ = envs.NONEXISTENT_VAR
+
+
+ def test_dir_returns_all_env_vars():
+     env_vars = envs.__dir__()
+     assert isinstance(env_vars, list)
+     assert len(env_vars) == len(environment_variables)
+     assert "JAX_PLATFORMS" in env_vars
+     assert "TPU_NAME" in env_vars
+     assert "SKIP_JAX_PRECOMPILE" in env_vars
+     assert "MODEL_IMPL_TYPE" in env_vars
+
+
+ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("TPU_WORKER_ID", "0")
+     assert envs.TPU_WORKER_ID == "0"
+
+     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
+     assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
+     assert envs.TPU_MULTIHOST_BACKEND == "xla"
+
+
+ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
+     assert envs.PREFILL_SLICES == "0,1,2,3"
+
+     monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
+     assert envs.DECODE_SLICES == "4,5,6,7"
+
+
+ def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
+     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+ def test_cache_preserves_values_across_env_changes(
+         monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+
+     enable_envs_cache()
+
+     assert envs.JAX_PLATFORMS == "tpu"
+
+     # Change environment variable
+     monkeypatch.setenv("JAX_PLATFORMS", "cpu")
+
+     # Cached value should still be "tpu"
+     assert envs.JAX_PLATFORMS == "tpu"
+
+     # Reset envs.__getattr__ back to non-cached version
+     envs.__getattr__ = envs.__getattr__.__wrapped__
+
+     # Now it should reflect the new value
+     assert envs.JAX_PLATFORMS == "cpu"
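Editor's note: the tests above pin down the contract of tpu_inference.envs: typed per-variable defaults, lowercase normalization for some settings, a module-level `__getattr__` accessor, and an `enable_envs_cache()` switch that wraps that accessor with `functools.cache` (hence the `cache_info` / `__wrapped__` checks). A minimal sketch of that pattern follows; it is an illustration of the mechanism, not the package's actual envs.py, and the defaults and parsing shown are assumptions.

```python
# Editor's sketch (not the package's actual envs.py): the module-level
# __getattr__ pattern that tests/test_envs.py exercises.
import functools
import os

environment_variables = {
    # string with empty-string default
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", ""),
    # string with None default
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),
    # boolean flag parsed from "0"/"1"
    "SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
    # integer setting
    "PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
    # lowercased string setting with a default
    "MODEL_IMPL_TYPE": lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
}


def __getattr__(name: str):
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    return list(environment_variables.keys())


def enable_envs_cache():
    # Wrap the accessor with functools.cache so each variable is read once;
    # the tests undo this via envs.__getattr__.__wrapped__.
    global __getattr__
    __getattr__ = functools.cache(__getattr__)
```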
tests/test_utils.py CHANGED
@@ -75,25 +75,34 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
  mock_device2 = MagicMock()
  devices = [mock_device1, mock_device2]

- # Create mock device buffers
- mock_buffer1_dev1 = MagicMock()
- mock_buffer1_dev1.device = mock_device1
- mock_buffer1_dev1.nbytes = 2000 # 2000 bytes on device1
+ # Create mock addressable shards with data property
+ mock_data1_dev1 = MagicMock()
+ mock_data1_dev1.device = mock_device1
+ mock_data1_dev1.nbytes = 2000 # 2000 bytes on device1

- mock_buffer1_dev2 = MagicMock()
- mock_buffer1_dev2.device = mock_device2
- mock_buffer1_dev2.nbytes = 2000 # 2000 bytes on device2
+ mock_data1_dev2 = MagicMock()
+ mock_data1_dev2.device = mock_device2
+ mock_data1_dev2.nbytes = 2000 # 2000 bytes on device2

- mock_buffer2_dev1 = MagicMock()
- mock_buffer2_dev1.device = mock_device1
- mock_buffer2_dev1.nbytes = 1000 # 1000 bytes on device1
+ mock_data2_dev1 = MagicMock()
+ mock_data2_dev1.device = mock_device1
+ mock_data2_dev1.nbytes = 1000 # 1000 bytes on device1

- # Create mock arrays with device buffers
+ mock_shard1_dev1 = MagicMock()
+ mock_shard1_dev1.data = mock_data1_dev1
+
+ mock_shard1_dev2 = MagicMock()
+ mock_shard1_dev2.data = mock_data1_dev2
+
+ mock_shard2_dev1 = MagicMock()
+ mock_shard2_dev1.data = mock_data2_dev1
+
+ # Create mock arrays with addressable_shards
  mock_array1 = MagicMock()
- mock_array1.device_buffers = [mock_buffer1_dev1, mock_buffer1_dev2]
+ mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]

  mock_array2 = MagicMock()
- mock_array2.device_buffers = [mock_buffer2_dev1]
+ mock_array2.addressable_shards = [mock_shard2_dev1]

  mock_live_arrays.return_value = [mock_array1, mock_array2]

@@ -159,7 +168,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
  "head_dim, expected_padded_head_dim",
  [
  (1, 128),
- (64, 128),
+ (64, 64),
  (127, 128),
  (128, 128),
  (129, 256),
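Editor's note: the updated mocks reflect a move in the HBM-usage helper from the removed `Array.device_buffers` attribute to `Array.addressable_shards`, whose `.data` field carries the per-device buffer. A rough sketch of the accounting these mocks imply; the function name and return shape are assumptions, not the package's code:

```python
# Editor's sketch: per-device HBM usage summed over live JAX arrays via
# addressable_shards, as the updated mocks in test_utils.py suggest.
from collections import defaultdict

import jax


def hbm_usage_bytes_per_device():
    usage = defaultdict(int)
    for array in jax.live_arrays():
        for shard in array.addressable_shards:
            # shard.data is the single-device array backing this shard
            usage[shard.data.device] += shard.data.nbytes
    return dict(usage)
```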
tpu_inference/core/core_tpu.py CHANGED
@@ -29,6 +29,7 @@ from vllm.v1.request import Request, RequestStatus

  from tpu_inference import utils as common_utils
  from tpu_inference.core import disagg_executor, disagg_utils
+ from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
  # ======================================================================================
  # Imports for _DisaggOrchestrator (decoupled from vLLM)
  # ======================================================================================
@@ -186,6 +187,8 @@ class _DisaggOrchestrator:
  if model_output is None:
  model_output = prefill_engine.model_executor.sample_tokens(
  grammar_output)
+ if isinstance(model_output, AsyncTPUModelRunnerOutput):
+ model_output = model_output.get_output()

  if scheduler_output.total_num_scheduled_tokens > 0:
  logger.debug(f"Prefill result: {model_output}")
@@ -218,15 +221,16 @@ class _DisaggOrchestrator:
  f"request-{req_id}: tokens={request.all_token_ids} after prefill"
  )
  # Remove request from the prefill engine.
+ if req_id in prefill_engine.scheduler.requests:
+ request = prefill_engine.scheduler.requests[req_id]
+ prefill_engine.scheduler.running.remove(request)
+ prefill_engine.scheduler.encoder_cache_manager.free(
+ request)

- request = prefill_engine.scheduler.requests[req_id]
- prefill_engine.scheduler.running.remove(request)
- prefill_engine.scheduler.encoder_cache_manager.free(
- request)
+ prefill_engine.scheduler.kv_cache_manager.free(
+ request)

- prefill_engine.scheduler.kv_cache_manager.free(request)
-
- prefill_engine.scheduler.requests.pop(req_id)
+ prefill_engine.scheduler.requests.pop(req_id)

  for output in (engine_core_outputs.items()
  if engine_core_outputs else ()):
@@ -335,8 +339,10 @@ class _DisaggOrchestrator:
  new_block_ids = kv_cache_manager.get_block_ids(req_id)
  logger.debug(
  f"inserting {req_id} new_block_ids {new_block_ids}")
- assert (len(new_block_ids[0]) == math.ceil(
- prompt_tokens / self._config.cache_config.block_size))
+ if len(new_block_ids[0]) != math.ceil(
+ prompt_tokens / self._config.cache_config.block_size):
+ logger.warning("Running out of blocks in decode engine! ")
+ break

  decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
  vllm_request, kv_cache, new_block_ids)
@@ -366,6 +372,8 @@ class _DisaggOrchestrator:
  if model_output is None:
  model_output = decode_engine.model_executor.sample_tokens(
  grammar_output)
+ if isinstance(model_output, AsyncTPUModelRunnerOutput):
+ model_output = model_output.get_output()

  if scheduler_output.total_num_scheduled_tokens > 0:
  logger.debug(f"Decode result: {model_output}")
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -131,10 +131,17 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
  f"current platform {current_platform.device_name} does not "
  "support ray.")

- placement_group_specs: List[Dict[str, float]] = [{
- device_str:
- node['Resources'][device_str]
- } for node in ray.nodes()]
+ pp_size = self.parallel_config.pipeline_parallel_size
+ placement_group_specs: List[Dict[str, float]] = []
+ if pp_size == 1:
+ placement_group_specs = [{
+ device_str: node['Resources'][device_str]
+ } for node in ray.nodes()]
+ else:
+ num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+ placement_group_specs = [{
+ device_str: num_devices_per_pp_rank
+ } for _ in range(pp_size)]

  # vLLM engine is also a worker to execute model with an accelerator,
  # so it requires to have the device in a current node. Check if
@@ -329,6 +336,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
  all_kwargs = []
  for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
  local_rank = node_workers[node_id].index(rank)
+ ip = sorted_worker_metadata[rank].ip
+ prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
  kwargs = dict(
  vllm_config=self.vllm_config,
  local_rank=local_rank,
@@ -336,22 +345,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
  distributed_init_method=distributed_init_method,
  is_driver_worker=(not self.parallel_config)
  or (rank % self.parallel_config.tensor_parallel_size == 0),
+ ip=ip,
+ prev_worker_ip=prev_ip,
  )
  all_kwargs.append(kwargs)
  self.collective_rpc("init_worker", args=(all_kwargs, ))
  self.collective_rpc("init_device")
+ if self.parallel_config.pipeline_parallel_size > 1:
+ self._run_workers("initialize_pp_transfer_connect")
  self.collective_rpc("load_model")

  if self.use_ray_spmd_worker:
  for pp_rank in range(self.parallel_config.pipeline_parallel_size):
  self.pp_tp_workers.append([])
- for tp_rank in range(
- int(self.parallel_config.tensor_parallel_size //
- num_tpu_per_worker)):
- # PP=2, TP=4
- # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
- rank = (pp_rank * self.parallel_config.tensor_parallel_size
- ) + tp_rank
+ num_tp_workers = int(
+ self.parallel_config.tensor_parallel_size //
+ num_tpu_per_worker)
+ for tp_rank in range(num_tp_workers):
+ # PP=2, TP=4, num_tpu_per_worker=2
+ # pp_tp_workers = [[0, 1], [2, 3]]
+ rank = (pp_rank * num_tp_workers) + tp_rank
  assert len(self.pp_tp_workers[pp_rank]) == tp_rank
  assert pp_rank < len(self.pp_tp_workers)
  self.pp_tp_workers[pp_rank].append(self.workers[rank])
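Editor's note: two related fixes here. Placement groups are now sized per pipeline stage when `pipeline_parallel_size > 1`, and the `pp_tp_workers` indexing counts Ray workers (each of which may hold several TPU chips) rather than raw TP ranks. Using the figures from the diff's own comment, the new grouping works out as follows:

```python
# Editor's illustration of the corrected pp_tp_workers grouping.
# With PP=2, TP=4 and 2 TPU chips per Ray worker there are 4 workers total,
# grouped as [[0, 1], [2, 3]] instead of the old [[0, 1, 2, 3], [4, 5, 6, 7]].
pipeline_parallel_size = 2
tensor_parallel_size = 4
num_tpu_per_worker = 2

num_tp_workers = tensor_parallel_size // num_tpu_per_worker
pp_tp_workers = [[pp_rank * num_tp_workers + tp_rank
                  for tp_rank in range(num_tp_workers)]
                 for pp_rank in range(pipeline_parallel_size)]
assert pp_tp_workers == [[0, 1], [2, 3]]
```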
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py CHANGED
@@ -317,6 +317,20 @@ def _ragged_paged_attention_kernel(
  q_len = q_end - q_start
  kv_len = kv_lens_ref[seq_idx]

+ bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+ kv_len - sliding_window, 0) // bkv_sz
+
+ if sliding_window is None:
+ next_bkv_idx_start = 0
+ else:
+
+ def get_next_bkv_idx_start():
+ next_kv_len = kv_lens_ref[seq_idx + 1]
+ return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+
+ next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+ get_next_bkv_idx_start, lambda: 0)
+
  def debug_print(msg, *args):
  if debug_mode:
  pl.debug_print(msg, *args)
@@ -353,8 +367,8 @@ def _ragged_paged_attention_kernel(
  head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

  def load_with_init(ref, init_val):
- return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
- ref[...])
+ return jnp.where(bkv_idx == bkv_idx_start,
+ jnp.full_like(ref, init_val), ref[...])

  # Follow FlashAttention-2 forward pass.
  if q_scale is not None:
@@ -378,9 +392,6 @@ def _ragged_paged_attention_kernel(
  num_q_heads_per_kv_head)
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
  mask = q_span < k_span
- # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
- if sliding_window is not None:
- mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

  if soft_cap is not None:
  s = soft_cap * jnp.tanh(s / soft_cap)
@@ -391,7 +402,8 @@ def _ragged_paged_attention_kernel(
  sinks = attention_sink_ref[kv_head_idx]
  actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
  m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
- m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
+ m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
+ head_m_ref[...])
  else:
  m_prev = load_with_init(head_m_ref, -jnp.inf)

@@ -719,12 +731,19 @@ def _ragged_paged_attention_kernel(
  def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
  next_bkv_idx = bkv_idx + 1
  is_last_bkv = next_bkv_idx == num_bkv
- next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
  next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
  is_last_bq = next_bq_idx == num_bq
  next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
  next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+ next_bkv_idx = lax.select(
+ is_last_bkv,
+ lax.select(
+ is_last_bq,
+ next_bkv_idx_start,
+ bkv_idx_start,
+ ), next_bkv_idx)
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

  def compute_with_bq(bq_idx, _):
@@ -759,7 +778,7 @@ def _ragged_paged_attention_kernel(
  next_bkv_sem_idx)

  # Wait for cur bq if not ready yet
- @pl.when(bkv_idx == 0)
+ @pl.when(bkv_idx == bkv_idx_start)
  def wait_cur_bq():
  wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

@@ -808,7 +827,11 @@ def _ragged_paged_attention_kernel(
  kv_head_idx=kv_head_idx,
  )

- lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
+ lax.fori_loop(bkv_idx_start,
+ num_bkv,
+ compute_with_bkv,
+ None,
+ unroll=False)

  # Load acc and calculate final output.
  acc = acc_ref[...]
@@ -838,7 +861,7 @@ def _ragged_paged_attention_kernel(
  @pl.when(seq_idx == 0)
  def prologue():
  start_fetch_bq(0, 0, 0)
- start_fetch_bkv(0, 0, 0)
+ start_fetch_bkv(0, bkv_idx_start, 0)

  @pl.when(seq_idx < decode_end)
  def process_decode():
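Editor's note: the sliding-window handling moves from per-element masking to skipping whole KV blocks: the KV-block loop, the prefetch, and the accumulator initialization are all re-based on `bkv_idx_start = max(kv_len - sliding_window, 0) // bkv_sz`. A worked example with illustrative sizes:

```python
# Editor's illustration of bkv_idx_start for sliding-window attention.
# With kv_len=4096, sliding_window=1024 and KV block size bkv_sz=256, the
# kernel now begins the KV-block loop for this sequence at block
# (4096 - 1024) // 256 = 12 rather than at block 0.
kv_len = 4096          # hypothetical sequence KV length
sliding_window = 1024  # hypothetical window size
bkv_sz = 256           # hypothetical KV block size

bkv_idx_start = max(kv_len - sliding_window, 0) // bkv_sz
assert bkv_idx_start == 12
```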
tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py CHANGED
@@ -1231,6 +1231,13 @@ TUNED_BLOCK_SIZES = {
  },
  }
  },
+ 16: {
+ 'q_bfloat16_kv_bfloat16': {
+ 'q_head-8_kv_head-1_head-128': {
+ 262144: (128, 256),
+ }
+ }
+ },
  },
  'TPU v5e': {
  128: {
tpu_inference/layers/{jax → common}/attention_interface.py RENAMED
@@ -17,7 +17,7 @@ import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
  from tpu_inference.kernels.flash_attention.kernel import flash_attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.jax.sharding import ShardingAxisName
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.utils import get_megacore

  MAX_ALLOWED_PAGE_INDICES_N = (
tpu_inference/layers/common/quant_methods.py ADDED
@@ -0,0 +1,8 @@
+ UNQUANTIZED = "unquantized"
+ MXFP4 = "mxfp4"
+ AWQ = "awq"
+ COMPRESSED_TENSORS = "compressed-tensors"
+
+
+ def get_tpu_quant_method(quant_method: str) -> str:
+     return "tpu-" + quant_method
tpu_inference/layers/jax/attention/attention.py CHANGED
@@ -13,9 +13,9 @@ from tpu_inference import utils
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
  ragged_paged_attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.rope_interface import apply_rope
- from tpu_inference.layers.jax.sharding import ShardingAxisName

  KVCache = Tuple[jax.Array, jax.Array]

tpu_inference/layers/jax/sample/rejection_sampler.py CHANGED
@@ -12,7 +12,7 @@ import jax
  import jax.numpy as jnp
  import numpy as np

- from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
  from tpu_inference.layers.jax.sample.sampling_metadata import \
  TPUSupportedSamplingMetadata

tpu_inference/layers/jax/sample/sampling.py CHANGED
@@ -6,10 +6,10 @@ from jax.sharding import Mesh, NamedSharding
  from jax.sharding import PartitionSpec as P
  from vllm.v1.outputs import LogprobsTensors

- from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.sample.sampling_metadata import \
  TPUSupportedSamplingMetadata
- from tpu_inference.layers.jax.sharding import ShardingAxisName

  _SAMPLING_EPS = 1e-5

tpu_inference/layers/vllm/attention.py CHANGED
@@ -13,8 +13,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
  AttentionLayer, AttentionType)

  from tpu_inference import utils
+ from tpu_inference.layers.common.attention_interface import attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.jax.attention_interface import attention
  from tpu_inference.logger import init_logger
  from tpu_inference.models.vllm.vllm_model_wrapper_context import \
  get_vllm_model_wrapper_context
tpu_inference/layers/vllm/quantization/__init__.py CHANGED
@@ -5,10 +5,12 @@ from vllm.config import VllmConfig
  from vllm.model_executor.layers.quantization.base_config import \
  QuantizationConfig

+ from tpu_inference.layers.common import quant_methods
  from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
  from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
  from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
  VllmCompressedTensorsConfig # noqa: E501
+ from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
  from tpu_inference.layers.vllm.quantization.unquantized import \
  VllmUnquantizedConfig

@@ -19,8 +21,9 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
  # TODO(kyuyeunk): Add support for "tpu_int8".
  method_to_config: dict[str, str] = {
  None: VllmUnquantizedConfig,
- "compressed-tensors": VllmCompressedTensorsConfig,
- "awq": VllmAWQConfig,
+ quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+ quant_methods.AWQ: VllmAWQConfig,
+ quant_methods.MXFP4: VllmMxfp4Config,
  }
  if model_config.quantization not in method_to_config:
  raise NotImplementedError(
@@ -30,6 +33,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
  assert issubclass(quant_config, JaxCommonConfig)
  quant_config.set_configs(vllm_config, mesh)

- model_config.quantization = quant_config.get_name()
+ model_config.quantization = quant_methods.get_tpu_quant_method(
+ quant_config.get_name())
  return VllmConfig.get_quantization_config(model_config,
  vllm_config.load_config)
tpu_inference/layers/vllm/quantization/awq.py CHANGED
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
  is_layer_skipped, unpack_quantized_values_into_int32)
  from vllm.scalar_type import scalar_types

+ from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
  from tpu_inference.layers.vllm.linear_common import (
  slice_sharded_tensor_for_concatenation, torch_to_jax_param)
  from tpu_inference.layers.vllm.quantization.common import (
@@ -29,12 +30,12 @@ P = PartitionSpec
  logger = init_logger(__name__)


- @register_quantization_config("jax-awq")
+ @register_quantization_config(get_tpu_quant_method(AWQ))
  class VllmAWQConfig(AWQConfig, JaxCommonConfig):

  @classmethod
- def get_name(cls) -> str:
- return "jax-awq"
+ def get_name(cls):
+ return AWQ

  def get_supported_act_dtypes(self) -> list[torch.dtype]:
  # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py CHANGED
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
  from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
  find_matched_target, should_ignore_layer)

+ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+ get_tpu_quant_method)
  from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
  from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
  VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -30,12 +32,12 @@ P = PartitionSpec
  logger = init_logger(__name__)


- @register_quantization_config("jax-compressed-tensors")
+ @register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
  class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):

  @classmethod
  def get_name(cls) -> str:
- return "jax-compressed-tensors"
+ return COMPRESSED_TENSORS

  def get_scheme(self,
  layer: torch.nn.Module,