tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. tests/core/test_dp_scheduler.py +128 -71
  2. tests/e2e/test_data_parallel.py +176 -280
  3. tests/e2e/test_hybrid_kvcache.py +219 -0
  4. tests/e2e/test_speculative_decoding.py +26 -6
  5. tests/layers/jax/test_qwix.py +1 -1
  6. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
  7. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
  8. tests/layers/vllm/test_mxfp4.py +25 -10
  9. tests/layers/vllm/test_unquantized.py +61 -31
  10. tests/layers/vllm/utils.py +19 -4
  11. tests/models/common/test_model_loader.py +2 -2
  12. tests/models/jax/test_qwen2_5_vl.py +10 -11
  13. tests/runner/test_multimodal_manager.py +3 -3
  14. tests/runner/test_tpu_runner.py +67 -8
  15. tests/runner/test_tpu_runner_dp.py +66 -0
  16. tpu_inference/core/sched/dp_scheduler.py +65 -40
  17. tpu_inference/kernels/mla/v1/kernel.py +7 -26
  18. tpu_inference/layers/common/sharding.py +8 -3
  19. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
  20. tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
  21. tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
  22. tpu_inference/layers/jax/sample/sampling.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +51 -47
  24. tpu_inference/layers/vllm/quantization/common.py +14 -13
  25. tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
  26. tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
  27. tpu_inference/layers/vllm/sharding.py +7 -4
  28. tpu_inference/models/common/model_loader.py +11 -14
  29. tpu_inference/models/jax/llama3.py +13 -10
  30. tpu_inference/models/jax/llama_guard_4.py +1 -1
  31. tpu_inference/models/jax/qwen2.py +3 -2
  32. tpu_inference/models/jax/qwen2_5_vl.py +4 -4
  33. tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
  34. tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
  35. tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
  36. tpu_inference/platforms/tpu_platform.py +7 -7
  37. tpu_inference/runner/compilation_manager.py +43 -33
  38. tpu_inference/runner/kv_cache_manager.py +1 -2
  39. tpu_inference/runner/multimodal_manager.py +1 -1
  40. tpu_inference/runner/tpu_runner.py +12 -9
  41. tpu_inference/utils.py +31 -30
  42. tpu_inference/worker/tpu_worker.py +5 -2
  43. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
  44. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
  45. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
  46. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
  47. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ from collections import defaultdict, deque
  from dataclasses import dataclass
  from enum import Enum
  from multiprocessing import Process, Queue
+ from time import time
  from typing import Any, Dict, List, Optional, Tuple

  import cloudpickle
@@ -102,7 +103,7 @@ def _disable_cloudpickle():
  def _scheduler_worker_process(
  rank: int,
  input_queue: Queue,
- output_queue: Queue,
+ output_queues: Dict[str, Queue],
  vllm_config: Any,
  kv_cache_config: Any,
  structured_output_manager: Any,
@@ -135,55 +136,55 @@ def _scheduler_worker_process(
  case SchedulerCommand.ADD_REQUEST:
  request = data
  scheduler.add_request(request)
- output_queue.put(None) # Signal completion
+ output_queues[command.value].put(None) # Signal completion

  case SchedulerCommand.SCHEDULE:
  output = scheduler.schedule()
- output_queue.put(output)
+ output_queues[command.value].put(output)

  case SchedulerCommand.FINISH_REQUESTS:
  request_ids, finished_status = data
  scheduler.finish_requests(request_ids, finished_status)
- output_queue.put(None) # Signal completion
+ output_queues[command.value].put(None) # Signal completion

  case SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS:
  draft_token_ids = data
  scheduler.update_draft_token_ids(draft_token_ids)
- output_queue.put(None) # Signal completion
+ output_queues[command.value].put(None) # Signal completion

  case SchedulerCommand.UPDATE_FROM_OUTPUT:
  scheduler_output, model_runner_output = data
  result = scheduler.update_from_output(
  scheduler_output, model_runner_output)
- output_queue.put(result)
+ output_queues[command.value].put(result)

  case SchedulerCommand.GET_GRAMMAR_BITMASK:
  scheduler_output = data
  result = scheduler.get_grammar_bitmask(scheduler_output)
- output_queue.put(result)
+ output_queues[command.value].put(result)

  case SchedulerCommand.MAKE_STATS:
  spec_decoding_stats, kv_connector_stats = data
  result = scheduler.make_stats(spec_decoding_stats,
  kv_connector_stats)
- output_queue.put(result)
+ output_queues[command.value].put(result)

  case SchedulerCommand.RESET_PREFIX_CACHE:
  result = scheduler.reset_prefix_cache()
- output_queue.put(result)
+ output_queues[command.value].put(result)

  case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
  result = scheduler.get_num_unfinished_requests()
- output_queue.put(result)
+ output_queues[command.value].put(result)

  case SchedulerCommand.HAS_FINISHED_REQUESTS:
  result = scheduler.has_finished_requests()
- output_queue.put(result)
+ output_queues[command.value].put(result)

  case SchedulerCommand.GET_REQUEST_COUNTS:
  running = len(scheduler.running)
  waiting = len(scheduler.waiting)
- output_queue.put((running, waiting))
+ output_queues[command.value].put((running, waiting))

  case SchedulerCommand.GET_TOKEN_COUNT:
  # Calculate total tokens across running and waiting requests
@@ -192,30 +193,29 @@ def _scheduler_worker_process(
  total_tokens += len(req.all_token_ids)
  for req in scheduler.waiting:
  total_tokens += len(req.all_token_ids)
- output_queue.put(total_tokens)
+ output_queues[command.value].put(total_tokens)

  case SchedulerCommand.GET_COMPUTED_BLOCKS:
  request = data
  blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
  request)
- output_queue.put((blocks, cached_tokens))
+ output_queues[command.value].put((blocks, cached_tokens))

  case SchedulerCommand.SHUTDOWN:
  scheduler.shutdown()
- output_queue.put(None) # Signal completion
+ output_queues[command.value].put(None) # Signal completion
  break
  case _:
  error = SchedulerWorkerError(
  rank, f"Unknown command: {command}")
- output_queue.put(error)
+ output_queues[command.value].put(error)
  raise error

  except Exception as e:
  logger.error(f"Error in scheduler worker {rank}: {e}",
  exc_info=True)
- # Put error on output queue
  error = SchedulerWorkerError(rank, str(e))
- output_queue.put(error)
+ output_queues[command.value].put(error)


  @dataclass
@@ -276,26 +276,29 @@ class DPScheduler(SchedulerInterface):
  # Enable cloudpickle for multiprocessing to handle local functions
  _enable_cloudpickle()

- # Create worker processes with one input and one output queue each
+ # Create worker processes with separate output queues for each command type
  import multiprocessing
  ctx = multiprocessing.get_context('fork')
  self.input_queues: List[Queue] = []
- self.output_queues: List[Queue] = []
+ self.output_queues: Dict[Tuple[int, str], Queue] = {}
  self.processes: List[Process] = []

  for rank in range(self.dp_size):
  input_queue = ctx.Queue()
- output_queue = ctx.Queue()
-
  self.input_queues.append(input_queue)
- self.output_queues.append(output_queue)
+
+ output_queues_for_rank: Dict[str, Queue] = {}
+ for cmd in SchedulerCommand:
+ output_queues_for_rank[cmd.value] = ctx.Queue()
+ self.output_queues[(
+ rank, cmd.value)] = output_queues_for_rank[cmd.value]

  process = ctx.Process(
  target=_scheduler_worker_process,
  args=(
  rank,
  input_queue,
- output_queue,
+ output_queues_for_rank,
  self.vllm_config,
  self.per_rank_kv_cache_configs[rank],
  structured_output_manager,
@@ -323,8 +326,24 @@ class DPScheduler(SchedulerInterface):
  rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
  self.per_rank_kv_cache_configs.append(rank_config)

- def _get_result_from_queue(self, queue: Queue) -> Any:
- result = queue.get()
+ def _get_result_from_queue(self, rank: int,
+ command: SchedulerCommand) -> Any:
+ """Get result from the output queue for a specific rank and command type."""
+ queue_obj = self.output_queues[(rank, command.value)]
+ try:
+ start_time = time()
+ result = queue_obj.get()
+ end_time = time()
+ if end_time - start_time > 1.0:
+ logger.warning(
+ f"Long wait time ({end_time - start_time:.2f}s) for rank {rank} "
+ f"command {command.value} response.")
+ except EOFError as e:
+ raise RuntimeError(
+ f"Queue error for rank {rank} command {command.value}: "
+ "Worker process terminated unexpectedly. "
+ "This may indicate a crash in the scheduler worker process."
+ ) from e
  if isinstance(result, SchedulerWorkerError):
  raise result
  return result
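Note: the rewritten `_get_result_from_queue` pairs every blocking read with the queue dedicated to that command type, so replies for different commands can no longer be consumed out of order. A minimal, self-contained sketch of the same request/reply pattern (illustrative names only, not the tpu_inference API) might look like this:

```python
# Sketch: one reply queue per command type, plus a timed blocking read.
# Assumes a platform with the 'fork' start method; all names are hypothetical.
import time
from enum import Enum
from multiprocessing import get_context


class Command(Enum):
    ADD = "add"
    COUNT = "count"
    SHUTDOWN = "shutdown"


def worker(input_queue, output_queues):
    items = []
    while True:
        command, data = input_queue.get()
        if command is Command.ADD:
            items.append(data)
            output_queues[command.value].put(None)  # ack
        elif command is Command.COUNT:
            output_queues[command.value].put(len(items))
        elif command is Command.SHUTDOWN:
            output_queues[command.value].put(None)
            break


def get_result(output_queues, command, warn_after_s=1.0):
    # Blocking read from the queue dedicated to this command type.
    start = time.time()
    result = output_queues[command.value].get()
    if time.time() - start > warn_after_s:
        print(f"slow response for command {command.value}")
    return result


if __name__ == "__main__":
    ctx = get_context("fork")
    input_queue = ctx.Queue()
    output_queues = {cmd.value: ctx.Queue() for cmd in Command}
    proc = ctx.Process(target=worker, args=(input_queue, output_queues))
    proc.start()

    input_queue.put((Command.ADD, "req-0"))
    get_result(output_queues, Command.ADD)
    input_queue.put((Command.COUNT, None))
    print(get_result(output_queues, Command.COUNT))  # -> 1
    input_queue.put((Command.SHUTDOWN, None))
    get_result(output_queues, Command.SHUTDOWN)
    proc.join()
```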
@@ -337,7 +356,8 @@ class DPScheduler(SchedulerInterface):

  rank_tokens = {}
  for rank in range(self.dp_size):
- token_count = self._get_result_from_queue(self.output_queues[rank])
+ token_count = self._get_result_from_queue(
+ rank, SchedulerCommand.GET_TOKEN_COUNT)
  rank_tokens[rank] = token_count

  return rank_tokens
@@ -355,7 +375,7 @@ class DPScheduler(SchedulerInterface):
  best_cache_tokens = 0
  for rank in range(self.dp_size):
  blocks, cached_tokens = self._get_result_from_queue(
- self.output_queues[rank])
+ rank, SchedulerCommand.GET_COMPUTED_BLOCKS)
  if cached_tokens > best_cache_tokens:
  best_cache_tokens = cached_tokens
  best_cache_rank = rank
@@ -382,7 +402,7 @@ class DPScheduler(SchedulerInterface):
  self.assigned_dp_rank[request.request_id] = rank

  self.input_queues[rank].put((SchedulerCommand.ADD_REQUEST, request))
- self._get_result_from_queue(self.output_queues[rank])
+ self._get_result_from_queue(rank, SchedulerCommand.ADD_REQUEST)

  @time_function
  def schedule(self) -> DPSchedulerOutput:
@@ -402,7 +422,8 @@ class DPScheduler(SchedulerInterface):
  # Collect outputs from all workers (blocking)
  rank_outputs = []
  for rank in range(self.dp_size):
- output = self._get_result_from_queue(self.output_queues[rank])
+ output = self._get_result_from_queue(rank,
+ SchedulerCommand.SCHEDULE)
  rank_outputs.append(output)

  # Cache scheduler outputs to use in `update_from_output`
@@ -531,7 +552,7 @@ class DPScheduler(SchedulerInterface):
  rank_scheduler_outputs[rank]))
  for rank in range(self.dp_size):
  grammar_output = self._get_result_from_queue(
- self.output_queues[rank])
+ rank, SchedulerCommand.GET_GRAMMAR_BITMASK)
  if grammar_output is not None:
  combined_structured_output_request_ids.extend(
  grammar_output.structured_output_request_ids)
@@ -572,7 +593,7 @@ class DPScheduler(SchedulerInterface):
  combined_engine_outputs = defaultdict(list)
  for rank in range(self.dp_size):
  rank_engine_outputs = self._get_result_from_queue(
- self.output_queues[rank])
+ rank, SchedulerCommand.UPDATE_FROM_OUTPUT)
  for client_idx, engine_output in rank_engine_outputs.items():
  combined_engine_outputs[client_idx].append(engine_output)

@@ -640,7 +661,7 @@ class DPScheduler(SchedulerInterface):
  for rank, req_ids in rank_request_ids.items():
  self.input_queues[rank].put(
  (SchedulerCommand.FINISH_REQUESTS, (req_ids, finished_status)))
- self._get_result_from_queue(self.output_queues[rank])
+ self._get_result_from_queue(rank, SchedulerCommand.FINISH_REQUESTS)

  def get_num_unfinished_requests(self) -> int:
  """Get total number of unfinished requests across all DP ranks."""
@@ -650,7 +671,8 @@ class DPScheduler(SchedulerInterface):

  total = 0
  for rank in range(self.dp_size):
- count = self._get_result_from_queue(self.output_queues[rank])
+ count = self._get_result_from_queue(
+ rank, SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS)
  total += count
  return total

@@ -663,7 +685,7 @@ class DPScheduler(SchedulerInterface):
  has_finished_any = False
  for rank in range(self.dp_size):
  has_finished_any |= self._get_result_from_queue(
- self.output_queues[rank])
+ rank, SchedulerCommand.HAS_FINISHED_REQUESTS)
  return has_finished_any

  def get_request_counts(self) -> Tuple[int, int]:
@@ -676,7 +698,7 @@ class DPScheduler(SchedulerInterface):
  total_waiting = 0
  for rank in range(self.dp_size):
  running, waiting = self._get_result_from_queue(
- self.output_queues[rank])
+ rank, SchedulerCommand.GET_REQUEST_COUNTS)
  total_running += running
  total_waiting += waiting
  return total_running, total_waiting
@@ -689,7 +711,8 @@ class DPScheduler(SchedulerInterface):

  all_success = True
  for rank in range(self.dp_size):
- success = self._get_result_from_queue(self.output_queues[rank])
+ success = self._get_result_from_queue(
+ rank, SchedulerCommand.RESET_PREFIX_CACHE)
  all_success &= success
  return all_success

@@ -715,7 +738,8 @@ class DPScheduler(SchedulerInterface):
  kv_connector_stats)))

  for rank in range(self.dp_size):
- rank_stats = self._get_result_from_queue(self.output_queues[rank])
+ rank_stats = self._get_result_from_queue(
+ rank, SchedulerCommand.MAKE_STATS)
  if rank_stats is None:
  continue

@@ -776,7 +800,8 @@ class DPScheduler(SchedulerInterface):
  self.input_queues[rank].put(
  (SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS,
  rank_draft_token_ids))
- self._get_result_from_queue(self.output_queues[rank])
+ self._get_result_from_queue(
+ rank, SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS)

  def shutdown(self) -> None:
  """Shutdown all DP rank scheduler worker processes."""
@@ -786,7 +811,7 @@ class DPScheduler(SchedulerInterface):

  # Wait for acknowledgment (blocking)
  for rank in range(self.dp_size):
- self._get_result_from_queue(self.output_queues[rank])
+ self._get_result_from_queue(rank, SchedulerCommand.SHUTDOWN)

  # Terminate and join all processes
  for process in self.processes:
@@ -822,36 +822,17 @@ def _mla_ragged_paged_attention_kernel(
  return q_nope_vec, q_rope_vec

  def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
- bitwidth = 32 // kv_packing
- repack_ty = jnp.dtype(f"uint{bitwidth}")
  bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
  bkv_sz_per_kv_packing, lkv_dim))
- bkvc_vec = bkvc_ref[...]
- bkvc_vecs = []
- for i in range(kv_packing):
- masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
- bkvc_vecs.append(masked_bkvc_vec)
- concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
- concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
- concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
- jnp.zeros_like(concated_bkvc_vec))
- concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
- kv_dtype)
+ bkvc_vec = pltpu.bitcast(bkvc_ref[...], kv_dtype)
+ bkvc_vec = lax.select(bkvc_mask, bkvc_vec, jnp.zeros_like(bkvc_vec))
+
  bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
  bkv_sz_per_kv_packing, r_dim))
- bkpe_vec = bkpe_ref[...]
- bkpe_vecs = []
- for i in range(kv_packing):
- masked_bkpe_vec = bkpe_vec >> (i * bitwidth)
- bkpe_vecs.append(masked_bkpe_vec)
- concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
- concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
- concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
- jnp.zeros_like(concated_bkpe_vec))
- concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
- kv_dtype)
-
- return concated_bkvc_vec, concated_bkpe_vec
+ bkpe_vec = pltpu.bitcast(bkpe_ref[...], kv_dtype)
+ bkpe_vec = lax.select(bkpe_mask, bkpe_vec, jnp.zeros_like(bkpe_vec))
+
+ return bkvc_vec, bkpe_vec

  def broadcast_minor(src, shape):
  if src.shape == shape:
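Note: the deleted block unpacked each packed 32-bit word with per-sub-word shifts and a concatenate before re-bitcasting; the new code reinterprets the loaded block with a single `pltpu.bitcast`. A rough host-side analogue of that reinterpretation, using `jax.lax.bitcast_convert_type` (recent JAX) instead of the Pallas TPU primitive, is sketched below; shapes are illustrative only:

```python
# Round-trip sketch: bf16 values packed into uint32 words are recovered with
# one bitcast rather than shift-and-concatenate unpacking.
import jax
import jax.numpy as jnp

kv_dtype = jnp.bfloat16
vals = jax.random.normal(jax.random.PRNGKey(0), (8, 4), dtype=kv_dtype)

# Pack pairs of bf16 values into uint32 words (the trailing axis of size 2
# collapses when bitcasting to the wider type).
packed = jax.lax.bitcast_convert_type(vals.reshape(8, 2, 2), jnp.uint32)  # (8, 2)

# A single bitcast back to bf16 restores the packed values.
unpacked = jax.lax.bitcast_convert_type(packed, kv_dtype).reshape(8, 4)   # (8, 4)
assert bool(jnp.all(unpacked == vals))
```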
@@ -40,7 +40,7 @@ class ShardingAxisNameBase:
  MLP_TENSOR = ('attn_dp', 'model', 'expert')
  MOE_TENSOR = ('attn_dp', 'model')
  EXPERT = ('attn_dp', 'expert', 'model')
- VOCAB = ('expert', 'model')
+ VOCAB = ('expert', 'attn_dp', 'model')


  class ShardingAxisName2D:
@@ -141,6 +141,11 @@ class ShardingConfigManager:
  kv_dtype = utils.get_jax_dtype_from_str_dtype(
  cache_dtype) or jnp.bfloat16
  packing = 4 // jnp.dtype(kv_dtype).itemsize
+
+ # The default head dim is 128 but 64 is also supported as a special case.
+ if vllm_config.model_config.get_head_size() == 64:
+ packing *= 2
+
  # When num_kv_heads * 2 / packing < TP, tensor parallelism would
  # duplicate KV heads across devices, wasting kv cache memory.
  # Use attention DP instead to reduce per-device num_kv_heads and
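Worked example of the packing arithmetic above (the packing factor is how many KV cache values fit in one 32-bit word, doubled for the 64-wide head special case); this standalone helper is for illustration only:

```python
# Mirrors the packing heuristic from the hunk above, as a standalone sketch.
import jax.numpy as jnp

def kv_packing(kv_dtype, head_size: int) -> int:
    packing = 4 // jnp.dtype(kv_dtype).itemsize
    if head_size == 64:      # special-cased smaller head dim
        packing *= 2
    return packing

assert kv_packing(jnp.bfloat16, 128) == 2       # 2-byte elements, default head dim
assert kv_packing(jnp.bfloat16, 64) == 4        # doubled for head dim 64
assert kv_packing(jnp.float8_e4m3fn, 128) == 4  # 1-byte elements
```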
@@ -186,8 +191,8 @@ class ShardingConfigManager:
  if sharding_strategy.attention_data_parallelism > 1:
  if not envs.NEW_MODEL_DESIGN:
  raise ValueError(
- "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
- "NEW_MODEL_DESIGN=True.")
+ "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set "
+ "NEW_MODEL_DESIGN=True")

  @property
  def total_dp_size(self) -> int:
@@ -30,6 +30,7 @@ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
  from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
  get_tuned_block_sizes
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.layers import RMSNorm
@@ -310,9 +311,8 @@ class MLA(nnx.Module):
  # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_SNH, v_SNH = utils.quantize_kv(
- k_SNH, v_SNH, self.kv_cache_quantized_dtype, k_scale,
- v_scale)
+ k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype,
+ k_SNH, v_SNH, k_scale, v_scale)

  new_kv_cache, outputs_TNH = self.attention(
  is_prefill,
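Note: the attention call sites now import `quantize_kv` from `tpu_inference.layers.common.quantization` and pass the target dtype first, then the K/V tensors, then the scales. A hypothetical per-tensor quantization helper with that argument order could look like the sketch below (the real implementation may differ):

```python
# Hypothetical sketch only: scaled cast of K/V to a low-precision KV-cache dtype.
import jax.numpy as jnp

def quantize_kv_sketch(quantized_dtype, k, v, k_scale=1.0, v_scale=1.0):
    if quantized_dtype is None:
        return k, v
    finfo = jnp.finfo(quantized_dtype)
    k_q = jnp.clip(k / k_scale, finfo.min, finfo.max).astype(quantized_dtype)
    v_q = jnp.clip(v / v_scale, finfo.min, finfo.max).astype(quantized_dtype)
    return k_q, v_q
```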
@@ -26,6 +26,7 @@ from tpu_inference import utils
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
  ragged_paged_attention_hd64
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding

@@ -248,9 +249,8 @@ class GptOssAttention(nnx.Module):
  # q_scale = self._q_scale
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_TKH, v_TKH = utils.quantize_kv(k_TKH, v_TKH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
+ k_TKH, v_TKH = quantize_kv(self.kv_cache_quantized_dtype, k_TKH,
+ v_TKH, k_scale, v_scale)

  with jax.named_scope("attn_op"):
  new_kv_cache, attn_out_TNH = self.attention(
@@ -19,8 +19,8 @@ import jax.numpy as jnp
  from flax import nnx
  from jax.sharding import Sharding

- from tpu_inference import utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.jax.attention.attention import Attention, KVCache
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
@@ -128,9 +128,8 @@ class Llama4Attention(Attention):
  # q_scale = self._q_scale
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
+ k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+ v_SKH, k_scale, v_scale)

  with jax.named_scope("attn_op"):
  new_kv_cache, outputs_TNH = self.attention(
@@ -42,7 +42,7 @@ def sample(
  if tpu_sampling_metadata.do_sampling:
  # Unshard the logits explicity to avoid latency increase.
  logits = jax.lax.with_sharding_constraint(
- logits, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
+ logits, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
  greedy_sampled = jnp.argmax(logits, axis=-1)
  if not tpu_sampling_metadata.do_sampling:
  return greedy_sampled
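Note: `jax.lax.with_sharding_constraint` only annotates the layout the compiler should produce at that point; XLA inserts whatever resharding is needed to get there. A minimal stand-alone usage on a one-device mesh, for illustration only:

```python
# Minimal with_sharding_constraint usage on a single-device mesh (illustrative).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
logits = jax.numpy.ones((8, 32000))

@jax.jit
def constrain(x):
    # Ask the compiler for a row-sharded (here trivially replicated) layout.
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("data", None)))

print(constrain(logits).sharding)
```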
@@ -16,12 +16,14 @@ import functools

  import jax
  from jax import numpy as jnp
- from jax.sharding import Mesh
+ from jax.sharding import Mesh, NamedSharding
  from jax.sharding import PartitionSpec as P

  from tpu_inference.kernels.megablox.gmm import gmm
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.vllm.linear_common import \
  slice_sharded_tensor_for_concatenation
+ from tpu_inference.utils import get_mesh_shape_product


  def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
@@ -137,25 +139,23 @@ def tensor_sharded_gmm_merged_column_parallel(
  group_offset=jnp.array(0),
  )

- rhs_scale_spec = None if rhs_scale is None else P(None, None, None,
- "model")
- rhs_bias_spec = None if rhs_bias is None else P(None, None, "model")
+ rhs_scale_spec = None if rhs_scale is None else P(
+ None, None, None, ShardingAxisName.MLP_TENSOR)
+ rhs_bias_spec = None if rhs_bias is None else P(
+ None, None, ShardingAxisName.MLP_TENSOR)

  gmm_result = jax.shard_map(
  _gmm,
  mesh=mesh,
- in_specs=(
- P("data", None),
- P(None, "model", None),
- rhs_scale_spec,
- rhs_bias_spec,
- P("data"),
- ),
- out_specs=(P("data", "model")),
+ in_specs=(P(ShardingAxisName.MLP_DATA,
+ None), P(None, ShardingAxisName.MLP_TENSOR,
+ None), rhs_scale_spec, rhs_bias_spec,
+ P(ShardingAxisName.MLP_DATA)),
+ out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
  check_vma=False,
  )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)

- tp_size = mesh.shape["model"]
+ tp_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
  intermediate_size = gmm_result.shape[-1] // 2
  output_sizes = [intermediate_size, intermediate_size]
  return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
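Note: `mesh.shape["model"]` assumed the tensor dimension maps to a single mesh axis; with `ShardingAxisName.MLP_TENSOR` being a tuple of axes, the new `get_mesh_shape_product` helper presumably folds several axis sizes into one size. A plausible sketch of such a helper (illustrative, not the actual tpu_inference.utils implementation):

```python
# Plausible sketch: product of the mesh sizes for one or more named axes.
import math
from typing import Sequence, Union

import jax
import numpy as np
from jax.sharding import Mesh


def mesh_shape_product(mesh: Mesh, axes: Union[str, Sequence[str]]) -> int:
    if isinstance(axes, str):
        axes = (axes,)
    return math.prod(mesh.shape[name] for name in axes)


# One-device example: every axis has size 1, so the product is 1.
devices = np.array(jax.devices()[:1]).reshape(1, 1, 1)
mesh = Mesh(devices, ("attn_dp", "model", "expert"))
assert mesh_shape_product(mesh, ("attn_dp", "model", "expert")) == 1
```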
@@ -175,7 +175,7 @@ def tensor_sharded_gmm_row_parallel(
  m, g, n, k = lhs.shape[0], *rhs.shape
  tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
  if rhs_bias is not None:
- shard_id = jax.lax.axis_index("model")
+ shard_id = jax.lax.axis_index(ShardingAxisName.MLP_TENSOR).sum()
  rhs_bias = jnp.where(shard_id == 0, rhs_bias, 0)
  out = gmm(
  lhs,
@@ -188,22 +188,19 @@ def tensor_sharded_gmm_row_parallel(
  transpose_rhs=True,
  group_offset=jnp.array(0),
  )
- return jax.lax.psum(out, axis_name="model")
+ return jax.lax.psum(out, axis_name=ShardingAxisName.MLP_TENSOR)

  num_blocks = 1 if rhs_scale is None else rhs_scale.shape[1]
- rhs_scale_spec = None if num_blocks == 1 else P(None, "model", None, None)
+ rhs_scale_spec = None if num_blocks == 1 else P(
+ None, ShardingAxisName.MLP_TENSOR, None, None)
  rhs_bias_spec = None if rhs_bias is None else P(None, None, None)
  gmm_result = jax.shard_map(
  _gmm_all_reduce,
  mesh=mesh,
- in_specs=(
- P("data", "model"),
- P(None, None, "model"),
- rhs_scale_spec,
- rhs_bias_spec,
- P("data"),
- ),
- out_specs=(P("data")),
+ in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR),
+ P(None, None, ShardingAxisName.MLP_TENSOR), rhs_scale_spec,
+ rhs_bias_spec, P(ShardingAxisName.MLP_DATA)),
+ out_specs=(P(ShardingAxisName.MLP_DATA)),
  check_vma=False,
  )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)

@@ -219,8 +216,8 @@ def expert_sharded_gmm(
  is_last_expert: bool,
  mesh: Mesh,
  ) -> jax.Array:
- ep_size = mesh.shape["model"]
-
+ ep_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+ ep_p_spec = P(ShardingAxisName.EXPERT)
  num_experts = rhs.shape[0]
  num_experts_per_shard = num_experts // ep_size
  group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
@@ -260,21 +257,22 @@ def expert_sharded_gmm(
  # 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0 D, D, D, D
  # shard-0 shard-1 shard-2 shard-3
  # Each shards has 3 (row A), 2 (row B), 5 (row C) and 4 (row D).
- lhs_spec = P("model") if is_last_expert else P()
- rhs_scale_spec = None if rhs_scale is None else P("model")
- rhs_bias_spec = None if rhs_bias is None else P("model")
+ lhs_spec = ep_p_spec if is_last_expert else P()
+ rhs_spec = ep_p_spec
+ rhs_scale_spec = None if rhs_scale is None else ep_p_spec
+ rhs_bias_spec = None if rhs_bias is None else ep_p_spec

  gmm_res = jax.shard_map(
  _gmm,
  mesh=mesh,
  in_specs=(
  lhs_spec,
- P("model", None, None),
+ rhs_spec,
  rhs_scale_spec,
  rhs_bias_spec,
  P(),
- P("model"),
+ ep_p_spec,
  ),
- out_specs=(P("model", None)),
+ out_specs=ep_p_spec,
  check_vma=False,
  )(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset)
@@ -320,15 +318,13 @@ def expert_sharded_gmm(
  # send_sizes_of_shard [3, 3, 3, 3] [2, 2, 2, 2] [5, 5, 5, 5] [4, 4, 4, 4 ]
  # output_offsets_of_shard [0, 0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0] [10,10,10,10]
  # recv_sizes_of_shard [3, 2, 5, 4] [3, 2, 5, 4] [3, 2, 5, 4] [3, 2, 5, 4]
- return jax.lax.ragged_all_to_all(
- operand,
- output,
- input_offsets_of_shard,
- send_sizes_of_shard,
- output_offsets_of_shard,
- recv_sizes_of_shard,
- axis_name="model",
- )
+ return jax.lax.ragged_all_to_all(operand,
+ output,
+ input_offsets_of_shard,
+ send_sizes_of_shard,
+ output_offsets_of_shard,
+ recv_sizes_of_shard,
+ axis_name=ShardingAxisName.EXPERT)

  # Use ragged_all_to_all to send the result from gmm for each expert to all
  # the shards. In the working example, the result would be:
@@ -350,8 +346,8 @@ def expert_sharded_gmm(
  return jax.shard_map(
  _ragged_all_to_all,
  mesh=mesh,
- in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
- out_specs=(P()),
+ in_specs=(ep_p_spec, ep_p_spec, ep_p_spec, ep_p_spec, P()),
+ out_specs=(P(ShardingAxisName.MLP_DATA)),
  check_vma=False,
  )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)

@@ -412,6 +408,9 @@ def fused_moe_func(
  assert gating_output.shape == (num_tokens, global_num_experts)

  topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
+ # All-gather topk weights for attention dp
+ topk_weights = jax.lax.with_sharding_constraint(
+ topk_weights, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
  topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
  if renormalize:
  topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
@@ -434,8 +433,10 @@ def fused_moe_func(
  x, group_sizes, topk_argsort_revert_indices = jax.shard_map(
  _process_tokens_locally,
  mesh=mesh,
- in_specs=(P("data", None), P("data", None)),
- out_specs=(P("data", None), P("data"), P("data")),
+ in_specs=(P(ShardingAxisName.MLP_DATA,
+ None), P(ShardingAxisName.MLP_DATA, None)),
+ out_specs=(P(ShardingAxisName.MLP_DATA, None),
+ P(ShardingAxisName.MLP_DATA), P(ShardingAxisName.MLP_DATA)),
  )(hidden_states, topk_indices)

  x = jnp.pad(x, ((0, 0), (0, padded_hidden_size - hidden_size)))
@@ -495,8 +496,11 @@ def fused_moe_func(
  x = jax.shard_map(
  _finalize_output,
  mesh=mesh,
- in_specs=(P("data", None), P("data"), P("data", None)),
- out_specs=(P("data", None)),
+ in_specs=(P(ShardingAxisName.MLP_DATA,
+ None), P(ShardingAxisName.MLP_DATA),
+ P(ShardingAxisName.MLP_DATA, None)),
+ out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
+ check_vma=False,
  )(x, topk_argsort_revert_indices, topk_weights)

  return x[:num_tokens, :hidden_size]