tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +12 -11
- tpu_inference/models/jax/llama3.py +4 -3
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -3
- tpu_inference/models/jax/qwen3.py +3 -2
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +17 -7
- tpu_inference/runner/compilation_manager.py +37 -17
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +199 -87
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +7 -6
- tpu_inference/worker/tpu_worker.py +159 -23
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-        # Fetch effective kv from kv cache.
-        def loop_body(i, offset):
-            sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-            _async_copy(
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size,
-                    sz)],
-                vmem_ref.at[pl.ds(i * page_size, sz)],
-                sem,
-                wait,
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
            )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )

-        # Fetch kv directly from new kv.
-        @pl.when(bkv_sz_frm_new > 0)
-        def _fetch_bkv_from_new_kv():
-            new_kv_len_start = q_end - kv_left_frm_new
-            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
-            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            # Fetch kv directly from new kv.
+            @pl.when(bkv_sz_frm_new > 0)
+            def _fetch_bkv_from_new_kv():
+                new_kv_len_start = q_end - kv_left_frm_new
+                debug_print("[RPA debug] new_kv_len_start={}",
+                            new_kv_len_start)
+                debug_print("[RPA debug] offset_in_bkv={}", offset)
+                _async_copy(
+                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                    sem,
+                    wait,
+                )
+
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
            _async_copy(
-                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
            )
-
-        return kv_len_start + offset, bkv_sz_frm_new
+            return kv_len_start + offset, bkv_sz_frm_new

     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
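The restructuring above (mirrored in kernel_hd64.py further down) splits _async_copy's callers into a start phase and a wait phase: with wait=False the loop issues one DMA per KV-cache page, and with wait=True the kernel now waits once on a single ref spanning the whole transfer, passed as both src and dst, instead of replaying the per-page loop. A minimal sketch of the start/wait idiom this presumably wraps, using Pallas TPU async copies (the helper body is an assumption; only the signature _async_copy(src, dst, sem, wait) is visible in the diff):

    from jax.experimental.pallas import tpu as pltpu

    def _async_copy(src, dst, sem, wait):
        # Hypothetical stand-in for the kernel's helper: a DMA descriptor
        # from src to dst whose completion is tracked by the semaphore `sem`.
        copy = pltpu.make_async_copy(src, dst, sem)
        if wait:
            copy.wait()   # block until `sem` records the transferred bytes
        else:
            copy.start()  # enqueue the DMA and return immediately

On the wait side only the transfer size and the semaphore matter, which is presumably why the else branch can pass the same vmem slice as both source and destination.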
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-        def loop_body(i, states):
-            update_sz, ignore = states
-            sz = jnp.minimum(page_size - ignore, update_sz)
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
            _async_copy(
-                vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
-                                  sz)],
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size + ignore,
-                    sz,
-                )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
            )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )

     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
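In the update path the carry makes the first page special: `ignore` rows at the head of the first touched page are already valid in the cache, so only page_size - ignore rows are written there, and every later page starts at offset 0. A standalone sketch of that bookkeeping (plain JAX for illustration, not the kernel code):

    from jax import lax
    import jax.numpy as jnp

    def page_copy_sizes(update_sz, ignore, page_size, num_pages):
        # Mirrors loop_body's (update_sz, ignore) carry: the first iteration
        # skips `ignore` rows, later iterations copy whole pages until the
        # remaining transfer size reaches zero.
        def body(i, state):
            left, ign = state
            sz = jnp.minimum(page_size - ign, left)
            # the real kernel issues an _async_copy of `sz` rows here
            return left - sz, jnp.int32(0)

        return lax.fori_loop(0, num_pages, body,
                             (jnp.int32(update_sz), jnp.int32(ignore)))

    # 5 rows to write, 3 rows of the first page already valid, page_size 4:
    # copies 1 row, then 4 rows, ending with a remaining size of 0.
    print(page_copy_sizes(5, 3, 4, 2))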
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -317,6 +317,20 @@ def _ragged_paged_attention_kernel(
         q_len = q_end - q_start
         kv_len = kv_lens_ref[seq_idx]

+        bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+            kv_len - sliding_window, 0) // bkv_sz
+
+        if sliding_window is None:
+            next_bkv_idx_start = 0
+        else:
+
+            def get_next_bkv_idx_start():
+                next_kv_len = kv_lens_ref[seq_idx + 1]
+                return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+
+            next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                          get_next_bkv_idx_start, lambda: 0)
+
         def debug_print(msg, *args):
             if debug_mode:
                 pl.debug_print(msg, *args)
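This hunk is the core of the sliding-window optimization: instead of scanning every KV block and masking the old ones out, the block loop starts at bkv_idx_start, the first block that can still fall inside the window; next_bkv_idx_start is computed for the following sequence so its first block fetch can be issued early. A worked decode-step example (one query at position kv_len - 1, attending keys in [kv_len - sliding_window, kv_len - 1]):

    # kv_len=1000, sliding_window=128, block size bkv_sz=256
    kv_len, sliding_window, bkv_sz = 1000, 128, 256
    bkv_idx_start = max(kv_len - sliding_window, 0) // bkv_sz  # 872 // 256 == 3
    # Blocks 0..2 cover keys 0..767, all below the window start at key 872,
    # so only block 3 (keys 768..999) is visited.
    assert bkv_idx_start == 3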
@@ -353,8 +367,8 @@ def _ragged_paged_attention_kernel(
         head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

         def load_with_init(ref, init_val):
-            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
-                             ref[...])
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])

         # Follow FlashAttention-2 forward pass.
         if q_scale is not None:
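load_with_init is the online-softmax bootstrap: on the first processed KV block the running-max/sum refs still hold stale values from the previous sequence, so they are replaced by the init value. Since the block loop no longer starts at 0, "first block" now means bkv_idx == bkv_idx_start. A standalone illustration of the pattern:

    import jax.numpy as jnp

    def load_with_init(ref_value, bkv_idx, bkv_idx_start, init_val):
        # First processed block -> init value; later blocks -> running value.
        return jnp.where(bkv_idx == bkv_idx_start,
                         jnp.full_like(ref_value, init_val), ref_value)

    m_ref = jnp.array([0.7, -2.0])  # stale running max from a prior sequence
    print(load_with_init(m_ref, 3, 3, -jnp.inf))  # [-inf -inf]
    print(load_with_init(m_ref, 4, 3, -jnp.inf))  # [ 0.7 -2. ]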
@@ -378,9 +392,6 @@ def _ragged_paged_attention_kernel(
                                     num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
         mask = q_span < k_span
-        # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
-        if sliding_window is not None:
-            mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

         if soft_cap is not None:
             s = soft_cap * jnp.tanh(s / soft_cap)
@@ -391,7 +402,8 @@ def _ragged_paged_attention_kernel(
             sinks = attention_sink_ref[kv_head_idx]
             actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
             m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
-            m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
+            m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
+                               head_m_ref[...])
         else:
             m_prev = load_with_init(head_m_ref, -jnp.inf)

@@ -463,42 +475,54 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-        # Fetch effective kv from kv cache.
-        def loop_body(i, offset):
-            sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-            _async_copy(
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size,
-                    sz)],
-                vmem_ref.at[pl.ds(i * page_size, sz)],
-                sem,
-                wait,
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
            )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )

-        # Fetch kv directly from new kv.
-        @pl.when(bkv_sz_frm_new > 0)
-        def _fetch_bkv_from_new_kv():
-            new_kv_len_start = q_end - kv_left_frm_new
-            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
-            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            # Fetch kv directly from new kv.
+            @pl.when(bkv_sz_frm_new > 0)
+            def _fetch_bkv_from_new_kv():
+                new_kv_len_start = q_end - kv_left_frm_new
+                debug_print("[RPA debug] new_kv_len_start={}",
+                            new_kv_len_start)
+                debug_print("[RPA debug] offset_in_bkv={}", offset)
+                _async_copy(
+                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                    sem,
+                    wait,
+                )
+
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
            _async_copy(
-                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
            )
-
-        return kv_len_start + offset, bkv_sz_frm_new
+            return kv_len_start + offset, bkv_sz_frm_new

     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -534,30 +558,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-        def loop_body(i, states):
-            update_sz, ignore = states
-            sz = jnp.minimum(page_size - ignore, update_sz)
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
            _async_copy(
-                vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
-                                  sz)],
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size + ignore,
-                    sz,
-                )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
            )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )

     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
@@ -719,12 +754,19 @@ def _ragged_paged_attention_kernel(
     def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
         next_bkv_idx = bkv_idx + 1
         is_last_bkv = next_bkv_idx == num_bkv
-        next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
         next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
         is_last_bq = next_bq_idx == num_bq
         next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+        next_bkv_idx = lax.select(
+            is_last_bkv,
+            lax.select(
+                is_last_bq,
+                next_bkv_idx_start,
+                bkv_idx_start,
+            ), next_bkv_idx)
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

     def compute_with_bq(bq_idx, _):
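With per-sequence start blocks, the wrap-around bookkeeping can no longer reset the KV index to a constant 0: after the last KV block of a bq it restarts at this sequence's bkv_idx_start, and after the last bq of the sequence at the next sequence's next_bkv_idx_start (computed earlier from kv_lens_ref[seq_idx + 1]). A standalone restatement of the nested select, meant for traced scalars inside the kernel:

    from jax import lax

    def next_kv_block(is_last_bkv, is_last_bq, bkv_idx,
                      bkv_idx_start, next_bkv_idx_start):
        # not the last block     -> advance to bkv_idx + 1
        # last block of a bq     -> restart at this sequence's start block
        # last block of last bq  -> start block of the next sequence
        return lax.select(
            is_last_bkv,
            lax.select(is_last_bq, next_bkv_idx_start, bkv_idx_start),
            bkv_idx + 1)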
@@ -759,7 +801,7 @@ def _ragged_paged_attention_kernel(
                 next_bkv_sem_idx)

             # Wait for cur bq if not ready yet
-            @pl.when(bkv_idx == 0)
+            @pl.when(bkv_idx == bkv_idx_start)
             def wait_cur_bq():
                 wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

@@ -808,7 +850,11 @@ def _ragged_paged_attention_kernel(
                 kv_head_idx=kv_head_idx,
             )

-        lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
+        lax.fori_loop(bkv_idx_start,
+                      num_bkv,
+                      compute_with_bkv,
+                      None,
+                      unroll=False)

         # Load acc and calculate final output.
         acc = acc_ref[...]
@@ -838,7 +884,7 @@ def _ragged_paged_attention_kernel(
     @pl.when(seq_idx == 0)
     def prologue():
         start_fetch_bq(0, 0, 0)
-        start_fetch_bkv(0, 0, 0)
+        start_fetch_bkv(0, bkv_idx_start, 0)

     @pl.when(seq_idx < decode_end)
     def process_decode():
--- a/tpu_inference/layers/jax/attention_interface.py
+++ b/tpu_inference/layers/common/attention_interface.py
@@ -17,7 +17,7 @@ import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.utils import get_megacore

 MAX_ALLOWED_PAGE_INDICES_N = (
--- a/tpu_inference/layers/jax/attention/attention.py
+++ b/tpu_inference/layers/jax/attention/attention.py
@@ -13,9 +13,9 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName

 KVCache = Tuple[jax.Array, jax.Array]

--- a/tpu_inference/layers/jax/sample/rejection_sampler.py
+++ b/tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -12,7 +12,7 @@ import jax
 import jax.numpy as jnp
 import numpy as np

-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata

--- a/tpu_inference/layers/jax/sample/sampling.py
+++ b/tpu_inference/layers/jax/sample/sampling.py
@@ -6,10 +6,10 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.v1.outputs import LogprobsTensors

-from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName

 _SAMPLING_EPS = 1e-5

--- a/tpu_inference/layers/vllm/attention.py
+++ b/tpu_inference/layers/vllm/attention.py
@@ -13,8 +13,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.logger import init_logger
 from tpu_inference.models.vllm.vllm_model_wrapper_context import \
     get_vllm_model_wrapper_context
--- a/tpu_inference/layers/vllm/quantization/__init__.py
+++ b/tpu_inference/layers/vllm/quantization/__init__.py
@@ -5,10 +5,12 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig

+from tpu_inference.layers.common import quant_methods
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig

@@ -19,8 +21,9 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     # TODO(kyuyeunk): Add support for "tpu_int8".
     method_to_config: dict[str, str] = {
         None: VllmUnquantizedConfig,
-        "compressed-tensors": VllmCompressedTensorsConfig,
-        "awq": VllmAWQConfig,
+        quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+        quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.MXFP4: VllmMxfp4Config,
     }
     if model_config.quantization not in method_to_config:
         raise NotImplementedError(
@@ -30,6 +33,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)

-    model_config.quantization =
+    model_config.quantization = quant_methods.get_tpu_quant_method(
+        quant_config.get_name())
     return VllmConfig.get_quantization_config(model_config,
                                               vllm_config.load_config)
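A likely motivation for this indirection: vLLM's quantization registry already holds the upstream "awq" and "compressed-tensors" entries, so the TPU overrides register under a derived name, and model_config.quantization is rewritten to match before the config is resolved. A sketch of the scheme; the constants' values and the derived-name format are assumptions (the real definitions live in tpu_inference/layers/common/quant_methods.py, extended in this release):

    # Hypothetical values; only the shape of the scheme is implied by the diff.
    AWQ = "awq"
    COMPRESSED_TENSORS = "compressed-tensors"
    MXFP4 = "mxfp4"

    def get_tpu_quant_method(name: str) -> str:
        # Register TPU variants under a name that cannot collide with the
        # quantization configs vLLM itself registers.
        return f"tpu-{name}"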
--- a/tpu_inference/layers/vllm/quantization/awq.py
+++ b/tpu_inference/layers/vllm/quantization/awq.py
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped, unpack_quantized_values_into_int32)
 from vllm.scalar_type import scalar_types

+from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
 from tpu_inference.layers.vllm.linear_common import (
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
@@ -29,12 +30,12 @@ P = PartitionSpec
 logger = init_logger(__name__)


-@register_quantization_config(
+@register_quantization_config(get_tpu_quant_method(AWQ))
 class VllmAWQConfig(AWQConfig, JaxCommonConfig):

     @classmethod
-    def get_name(cls)
-        return
+    def get_name(cls):
+        return AWQ

     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, should_ignore_layer)

+from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+                                                       get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
     VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -30,12 +32,12 @@ P = PartitionSpec
 logger = init_logger(__name__)


-@register_quantization_config(
+@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):

     @classmethod
     def get_name(cls) -> str:
-        return
+        return COMPRESSED_TENSORS

     def get_scheme(self,
                    layer: torch.nn.Module,
|