tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (58)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tests/test_envs.py +182 -0
  4. tests/test_utils.py +23 -14
  5. tpu_inference/__init__.py +22 -3
  6. tpu_inference/core/core_tpu.py +17 -9
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +2 -3
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +1 -1
  11. tpu_inference/executors/ray_distributed_executor.py +27 -11
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
  14. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  15. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  16. tpu_inference/layers/common/quant_methods.py +8 -0
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  19. tpu_inference/layers/jax/sample/sampling.py +2 -2
  20. tpu_inference/layers/vllm/attention.py +1 -1
  21. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  22. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  23. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  24. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  25. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  26. tpu_inference/layers/vllm/sharding.py +2 -2
  27. tpu_inference/lora/torch_punica_tpu.py +1 -2
  28. tpu_inference/models/common/model_loader.py +12 -11
  29. tpu_inference/models/jax/llama3.py +4 -3
  30. tpu_inference/models/jax/llama_eagle3.py +9 -5
  31. tpu_inference/models/jax/llama_guard_4.py +361 -0
  32. tpu_inference/models/jax/qwen2.py +3 -2
  33. tpu_inference/models/jax/qwen2_5_vl.py +4 -3
  34. tpu_inference/models/jax/qwen3.py +3 -2
  35. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  36. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  37. tpu_inference/platforms/tpu_platform.py +17 -7
  38. tpu_inference/runner/compilation_manager.py +37 -17
  39. tpu_inference/runner/kv_cache.py +1 -1
  40. tpu_inference/runner/kv_cache_manager.py +8 -2
  41. tpu_inference/runner/tpu_runner.py +199 -87
  42. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  43. tpu_inference/tpu_info.py +4 -3
  44. tpu_inference/utils.py +7 -6
  45. tpu_inference/worker/tpu_worker.py +159 -23
  46. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  47. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
  48. tpu_inference/mock/__init__.py +0 -0
  49. tpu_inference/mock/vllm_config_utils.py +0 -28
  50. tpu_inference/mock/vllm_envs.py +0 -1219
  51. tpu_inference/mock/vllm_logger.py +0 -212
  52. tpu_inference/mock/vllm_logging_utils.py +0 -15
  53. tpu_inference/models/jax/phi3.py +0 -376
  54. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  55. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  56. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  57. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  58. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-        # Fetch effective kv from kv cache.
-        def loop_body(i, offset):
-            sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-            _async_copy(
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size,
-                    sz)],
-                vmem_ref.at[pl.ds(i * page_size, sz)],
-                sem,
-                wait,
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )
 
-        # Fetch kv directly from new kv.
-        @pl.when(bkv_sz_frm_new > 0)
-        def _fetch_bkv_from_new_kv():
-            new_kv_len_start = q_end - kv_left_frm_new
-            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
-            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            # Fetch kv directly from new kv.
+            @pl.when(bkv_sz_frm_new > 0)
+            def _fetch_bkv_from_new_kv():
+                new_kv_len_start = q_end - kv_left_frm_new
+                debug_print("[RPA debug] new_kv_len_start={}",
+                            new_kv_len_start)
+                debug_print("[RPA debug] offset_in_bkv={}", offset)
+                _async_copy(
+                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                    sem,
+                    wait,
+                )
+
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
             _async_copy(
-                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-
-        return kv_len_start + offset, bkv_sz_frm_new
+            return kv_len_start + offset, bkv_sz_frm_new
 
     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-        def loop_body(i, states):
-            update_sz, ignore = states
-            sz = jnp.minimum(page_size - ignore, update_sz)
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
             _async_copy(
-                vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size +
-                    ignore,
-                    sz,
-                )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )
 
     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
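Note on the refactor above: _fetch_bkv and _update_kv_cache now branch on wait. The wait=False path only starts the per-page DMA copies; the wait=True path blocks once on a single descriptor spanning the whole transfer, built with src=dst so no data actually moves, presumably so one wait covers the bytes posted by all the earlier per-page copies on the same semaphore. A minimal sketch of a helper with this shape, assuming Pallas TPU's make_async_copy (the package's actual _async_copy is not shown in this diff and may differ):

    from jax.experimental.pallas import tpu as pltpu

    def _async_copy(src, dst, sem, wait):
        # One descriptor per call; sem counts the bytes it transfers.
        copy = pltpu.make_async_copy(src, dst, sem)
        if wait:
            copy.wait()   # block until sem has drained this descriptor's bytes
        else:
            copy.start()  # enqueue the DMA and return immediately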
@@ -317,6 +317,20 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]
 
+    bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+        kv_len - sliding_window, 0) // bkv_sz
+
+    if sliding_window is None:
+        next_bkv_idx_start = 0
+    else:
+
+        def get_next_bkv_idx_start():
+            next_kv_len = kv_lens_ref[seq_idx + 1]
+            return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+
+        next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                      get_next_bkv_idx_start, lambda: 0)
+
     def debug_print(msg, *args):
         if debug_mode:
             pl.debug_print(msg, *args)
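bkv_idx_start above is the index of the first KV block that can still intersect the sliding window, which lets the block loop and the prefetch below skip blocks that are entirely out of window. A worked example with illustrative numbers:

    kv_len, sliding_window, bkv_sz = 1000, 256, 128
    bkv_idx_start = max(kv_len - sliding_window, 0) // bkv_sz  # 744 // 128 == 5
    # Blocks 0..4 end at key 639, before the window start at key 744, so they
    # are skipped; block 5 (keys 640..767) is the first that overlaps the window.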
@@ -353,8 +367,8 @@ def _ragged_paged_attention_kernel(
         head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
         def load_with_init(ref, init_val):
-            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
-                             ref[...])
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])
 
         # Follow FlashAttention-2 forward pass.
         if q_scale is not None:
@@ -378,9 +392,6 @@ def _ragged_paged_attention_kernel(
                                     num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
         mask = q_span < k_span
-        # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
-        if sliding_window is not None:
-            mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
 
         if soft_cap is not None:
             s = soft_cap * jnp.tanh(s / soft_cap)
@@ -391,7 +402,8 @@ def _ragged_paged_attention_kernel(
             sinks = attention_sink_ref[kv_head_idx]
             actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
             m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
-            m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
+            m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
+                               head_m_ref[...])
         else:
             m_prev = load_with_init(head_m_ref, -jnp.inf)
 
@@ -463,42 +475,54 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-        # Fetch effective kv from kv cache.
-        def loop_body(i, offset):
-            sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-            _async_copy(
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size,
-                    sz)],
-                vmem_ref.at[pl.ds(i * page_size, sz)],
-                sem,
-                wait,
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )
 
-        # Fetch kv directly from new kv.
-        @pl.when(bkv_sz_frm_new > 0)
-        def _fetch_bkv_from_new_kv():
-            new_kv_len_start = q_end - kv_left_frm_new
-            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
-            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            # Fetch kv directly from new kv.
+            @pl.when(bkv_sz_frm_new > 0)
+            def _fetch_bkv_from_new_kv():
+                new_kv_len_start = q_end - kv_left_frm_new
+                debug_print("[RPA debug] new_kv_len_start={}",
+                            new_kv_len_start)
+                debug_print("[RPA debug] offset_in_bkv={}", offset)
+                _async_copy(
+                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                    sem,
+                    wait,
+                )
+
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
             _async_copy(
-                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-
-        return kv_len_start + offset, bkv_sz_frm_new
+            return kv_len_start + offset, bkv_sz_frm_new
 
     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -534,30 +558,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-        def loop_body(i, states):
-            update_sz, ignore = states
-            sz = jnp.minimum(page_size - ignore, update_sz)
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
             _async_copy(
-                vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size +
-                    ignore,
-                    sz,
-                )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )
 
     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
@@ -719,12 +754,19 @@ def _ragged_paged_attention_kernel(
     def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
         next_bkv_idx = bkv_idx + 1
         is_last_bkv = next_bkv_idx == num_bkv
-        next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
         next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
         is_last_bq = next_bq_idx == num_bq
         next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+        next_bkv_idx = lax.select(
+            is_last_bkv,
+            lax.select(
+                is_last_bq,
+                next_bkv_idx_start,
+                bkv_idx_start,
+            ), next_bkv_idx)
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
     def compute_with_bq(bq_idx, _):
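get_next_bkv_ids drives the double-buffered KV prefetch; the change above makes the block counter wrap to the first in-window block rather than 0 when it rolls over. A plain-Python equivalent of the updated logic (semaphore index elided; num_bkv, num_bq, bkv_idx_start and next_bkv_idx_start as in the kernel):

    def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx):
        next_bkv_idx = bkv_idx + 1
        is_last_bkv = next_bkv_idx == num_bkv
        next_bq_idx = bq_idx + 1 if is_last_bkv else bq_idx
        is_last_bq = next_bq_idx == num_bq
        next_bq_idx = 0 if is_last_bq else next_bq_idx
        next_seq_idx = seq_idx + 1 if is_last_bq else seq_idx
        if is_last_bkv:
            # Wrap to this sequence's first in-window block, or to the next
            # sequence's start when the last query block just finished.
            next_bkv_idx = next_bkv_idx_start if is_last_bq else bkv_idx_start
        return next_seq_idx, next_bq_idx, next_bkv_idx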
@@ -759,7 +801,7 @@ def _ragged_paged_attention_kernel(
                            next_bkv_sem_idx)
 
         # Wait for cur bq if not ready yet
-        @pl.when(bkv_idx == 0)
+        @pl.when(bkv_idx == bkv_idx_start)
         def wait_cur_bq():
             wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
 
@@ -808,7 +850,11 @@ def _ragged_paged_attention_kernel(
                 kv_head_idx=kv_head_idx,
             )
 
-        lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
+        lax.fori_loop(bkv_idx_start,
+                      num_bkv,
+                      compute_with_bkv,
+                      None,
+                      unroll=False)
 
         # Load acc and calculate final output.
         acc = acc_ref[...]
@@ -838,7 +884,7 @@ def _ragged_paged_attention_kernel(
     @pl.when(seq_idx == 0)
     def prologue():
         start_fetch_bq(0, 0, 0)
-        start_fetch_bkv(0, 0, 0)
+        start_fetch_bkv(0, bkv_idx_start, 0)
 
     @pl.when(seq_idx < decode_end)
     def process_decode():
@@ -1231,6 +1231,13 @@ TUNED_BLOCK_SIZES = {
                 },
             }
         },
+        16: {
+            'q_bfloat16_kv_bfloat16': {
+                'q_head-8_kv_head-1_head-128': {
+                    262144: (128, 256),
+                }
+            }
+        },
     },
     'TPU v5e': {
         128: {
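Judging by the key names, the table nests TPU generation, page size, query/KV dtype pair, head geometry, and padded context length, with the value pair being the tuned kernel block sizes; these level meanings are inferred from the names, not stated in the diff. An illustrative lookup of the new entry ('TPU vXe' stands in for the generation key enclosing this hunk, which is not visible here):

    blk_a, blk_b = TUNED_BLOCK_SIZES['TPU vXe'][16] \
        ['q_bfloat16_kv_bfloat16']['q_head-8_kv_head-1_head-128'][262144]
    # (blk_a, blk_b) == (128, 256)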
@@ -17,7 +17,7 @@ import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.utils import get_megacore
 
 MAX_ALLOWED_PAGE_INDICES_N = (
@@ -0,0 +1,8 @@
+UNQUANTIZED = "unquantized"
+MXFP4 = "mxfp4"
+AWQ = "awq"
+COMPRESSED_TENSORS = "compressed-tensors"
+
+
+def get_tpu_quant_method(quant_method: str) -> str:
+    return "tpu-" + quant_method
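The new quant_methods module centralizes the checkpoint-level method names and the TPU registration prefix used throughout the quantization hunks below. Usage implied directly by the definitions above:

    from tpu_inference.layers.common.quant_methods import (
        AWQ, COMPRESSED_TENSORS, get_tpu_quant_method)

    assert get_tpu_quant_method(AWQ) == "tpu-awq"
    assert get_tpu_quant_method(COMPRESSED_TENSORS) == "tpu-compressed-tensors"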
@@ -13,9 +13,9 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 
 KVCache = Tuple[jax.Array, jax.Array]
 
@@ -12,7 +12,7 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 
-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 
@@ -6,10 +6,10 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.v1.outputs import LogprobsTensors
 
-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 
 _SAMPLING_EPS = 1e-5
 
@@ -13,8 +13,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.logger import init_logger
 from tpu_inference.models.vllm.vllm_model_wrapper_context import \
     get_vllm_model_wrapper_context
@@ -5,10 +5,12 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 
+from tpu_inference.layers.common import quant_methods
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig
 
@@ -19,8 +21,9 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     # TODO(kyuyeunk): Add support for "tpu_int8".
     method_to_config: dict[str, str] = {
         None: VllmUnquantizedConfig,
-        "compressed-tensors": VllmCompressedTensorsConfig,
-        "awq": VllmAWQConfig,
+        quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+        quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.MXFP4: VllmMxfp4Config,
     }
     if model_config.quantization not in method_to_config:
         raise NotImplementedError(
@@ -30,6 +33,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)
 
-    model_config.quantization = quant_config.get_name()
+    model_config.quantization = quant_methods.get_tpu_quant_method(
+        quant_config.get_name())
     return VllmConfig.get_quantization_config(model_config,
                                               vllm_config.load_config)
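Taken together with the registration hunks below, the renaming works in two steps: the checkpoint's plain method name selects the TPU config class, then model_config.quantization is rewritten with the tpu- prefix, which should make vLLM resolve the TPU-registered class rather than its stock config. An illustrative trace, using "awq" as the incoming method:

    name = "awq"                            # model_config.quantization on entry
    config_cls = method_to_config[name]     # VllmAWQConfig
    model_config.quantization = quant_methods.get_tpu_quant_method(
        config_cls.get_name())              # "tpu-awq", the same key used by
                                            # @register_quantization_config below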
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped, unpack_quantized_values_into_int32)
 from vllm.scalar_type import scalar_types
 
+from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
 from tpu_inference.layers.vllm.linear_common import (
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
@@ -29,12 +30,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config("jax-awq")
+@register_quantization_config(get_tpu_quant_method(AWQ))
 class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     @classmethod
-    def get_name(cls) -> str:
-        return "jax-awq"
+    def get_name(cls):
+        return AWQ
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, should_ignore_layer)
 
+from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+                                                       get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
     VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -30,12 +32,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config("jax-compressed-tensors")
+@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return "jax-compressed-tensors"
+        return COMPRESSED_TENSORS
 
     def get_scheme(self,
                    layer: torch.nn.Module,