tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/kernels/ragged_paged_attention/v3/util.py
@@ -0,0 +1,47 @@
+"""Utility functions for ragged paged attention."""
+import jax
+from jax._src import dtypes
+
+
+def cdiv(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def align_to(x, a):
+    return cdiv(x, a) * a
+
+
+def get_dtype_packing(dtype):
+    bits = dtypes.bit_width(dtype)
+    return 32 // bits
+
+
+def next_power_of_2(x: int):
+    """Finds the smallest power of 2 >= x using bit manipulation.
+
+    Args:
+        x: The input number (should be an integer).
+
+    Returns:
+        The smallest integer power of 2 that is >= x.
+    """
+    assert x > 0
+    if x == 1:
+        return 1
+    return 1 << (x - 1).bit_length()
+
+
+def get_tpu_version() -> int:
+    """Returns the numeric version of the TPU, or -1 if not on TPU."""
+    kind = jax.devices()[0].device_kind
+    if 'TPU' not in kind:
+        return -1
+    if kind.endswith(' lite'):
+        kind = kind[:-len(' lite')]
+    if kind.endswith('p'):
+        kind = kind[:-1]
+    if kind == 'TPU7x':
+        return 7
+    assert kind[:-1] == 'TPU v', kind
+    return int(kind[-1])
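The helpers above are small sizing utilities used when tiling the ragged paged attention kernel: ceiling division, alignment, power-of-two rounding, and how many elements of a dtype pack into a 32-bit word. A minimal usage sketch (the import path is taken from the file listing above; the values follow directly from the definitions, and get_tpu_version is omitted because it needs a TPU device):

    import jax.numpy as jnp

    from tpu_inference.kernels.ragged_paged_attention.v3.util import (
        align_to, cdiv, get_dtype_packing, next_power_of_2)

    # 10 tokens in blocks of 3 -> 4 blocks; aligned up to a multiple of 128 -> 128.
    assert cdiv(10, 3) == 4
    assert align_to(10, 128) == 128

    # Smallest power of two >= 33.
    assert next_power_of_2(33) == 64

    # bfloat16 is 16 bits wide, so two elements pack into one 32-bit word.
    assert get_dtype_packing(jnp.bfloat16) == 2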
File without changes

File without changes
tpu_inference/layers/common/attention_metadata.py
@@ -0,0 +1,34 @@
+import functools
+from dataclasses import dataclass, field
+from typing import Any
+
+import jax
+
+
+@functools.partial(
+    jax.tree_util.register_dataclass,
+    data_fields=[
+        "input_positions",
+        "block_tables",
+        "seq_lens",
+        "query_start_loc",
+        "request_distribution",
+    ],
+    meta_fields=[],
+    drop_fields=["query_start_loc_cpu", "seq_lens_cpu"],
+)
+@dataclass
+class AttentionMetadata(object):
+    # (padded_total_num_scheduled_tokens,)
+    input_positions: jax.Array
+    # (max_num_seqs * max_num_blocks_per_req,)
+    block_tables: jax.Array = None
+    # (max_num_seqs,)
+    seq_lens: jax.Array = None
+    # (max_num_seqs + 1,)
+    query_start_loc: jax.Array = None
+    # (3,)
+    request_distribution: jax.Array = None
+
+    query_start_loc_cpu: Any = field(init=False)
+    seq_lens_cpu: Any = field(init=False)
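Because AttentionMetadata is registered as a dataclass pytree, the five data_fields become traced leaves while the two *_cpu fields listed in drop_fields stay host-side and are dropped when the object crosses a jax.jit boundary. A minimal sketch with hypothetical shapes (two requests, eight scheduled tokens), assuming the wheel is installed:

    import jax
    import jax.numpy as jnp

    from tpu_inference.layers.common.attention_metadata import AttentionMetadata

    md = AttentionMetadata(
        input_positions=jnp.arange(8, dtype=jnp.int32),
        block_tables=jnp.zeros((2 * 4,), dtype=jnp.int32),
        seq_lens=jnp.array([5, 3], dtype=jnp.int32),
        query_start_loc=jnp.array([0, 5, 8], dtype=jnp.int32),
        request_distribution=jnp.array([0, 0, 2], dtype=jnp.int32),
    )

    # Only the registered data_fields appear as pytree leaves, so the whole
    # object can be passed directly into a jitted runner function.
    assert len(jax.tree_util.tree_leaves(md)) == 5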
File without changes

File without changes
tpu_inference/layers/jax/attention/attention.py
@@ -0,0 +1,254 @@
+from dataclasses import InitVar, dataclass
+from typing import Any, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import Sharding
+from jax.experimental import shard_map
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference import utils
+from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+    ragged_paged_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.base import create_param
+from tpu_inference.layers.jax.rope_interface import apply_rope
+
+KVCache = Tuple[jax.Array, jax.Array]
+
+
+@dataclass(kw_only=True)
+class Attention(nnx.Module):
+    """An implementation of attention.
+
+    This module performs the attention mechanism for a transformer model,
+    including query, key, and value projections, application of Rotary
+    Position Embeddings (RoPE), and management of a KV cache for efficient
+    autoregressive generation. It supports both prefill and generation
+    (decode) modes and handles tensor sharding for distributed computation.
+
+    Attributes:
+        mesh: The JAX device mesh for distributed computation.
+    """
+    hidden_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    head_dim: int
+    rope_theta: float
+    rope_scaling: dict[str, Any]
+    dtype: jnp.dtype
+    mesh: Mesh
+    kv_cache_dtype: str
+
+    dnh_sharding: Sharding = ()
+    dkh_sharding: Sharding = ()
+    nhd_sharding: Sharding = ()
+
+    activation_q_td: Sharding = ()
+    query_tnh: P = P()
+    keyvalue_skh: P = P()
+
+    attn_o_tnh: P = P()
+    rngs: InitVar[nnx.Rngs]
+
+    random_init: bool = False
+    attention_chunk_size: int | None = None
+    rope_input_ordering: str = "split"
+
+    _q_scale: float = 1.0
+    _k_scale: float = 1.0
+    _v_scale: float = 1.0
+
+    kv_cache_quantized_dtype = None
+
+    def __post_init__(self, rngs: nnx.Rngs):
+        """Initializes the weight kernels for Q, K, V, and O projections."""
+        N = self.num_attention_heads
+        K = self.num_key_value_heads
+        D = self.hidden_size
+        H = self.head_dim
+
+        self.kernel_q_proj_DNH = create_param(rngs, (D, N, H),
+                                              self.dnh_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+        self.kernel_k_proj_DKH = create_param(rngs, (D, K, H),
+                                              self.dkh_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+        self.kernel_v_proj_DKH = create_param(rngs, (D, K, H),
+                                              self.dkh_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+        self.kernel_o_proj_NHD = create_param(rngs, (N, H, D),
+                                              self.nhd_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+
+        if self.kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                self.kv_cache_dtype)
+
+    def __call__(self,
+                 x,
+                 is_prefill,
+                 kv_cache: KVCache,
+                 attention_metadata: AttentionMetadata,
+                 use_attention_rope: bool = True):
+        """Performs the forward pass of the attention module.
+
+        This method computes the attention output by projecting the input `x`
+        to queries, keys, and values, applying RoPE, performing scaled
+        dot-product attention, and projecting the result back to the model
+        dimension. It updates and utilizes a KV cache.
+
+        Args:
+            x: The input tensor of shape `(seq_len, d_model)`.
+            is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+            kv_cache: The key-value cache for storing past attention states.
+            attention_metadata: Metadata for attention, such as input positions.
+            use_attention_rope: Whether to use RoPE.
+
+        Returns:
+            A tuple containing:
+                - The updated KV cache.
+                - The attention output tensor of shape
+                  `(batch_size, seq_len, d_model)`.
+        """
+        md = attention_metadata
+        x_SD = jnp.asarray(x, self.dtype)
+        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+        H = self.head_dim
+        with jax.named_scope("q_proj"):
+            q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
+                               self.kernel_q_proj_DNH.value)
+            if use_attention_rope:
+                q_TNH = apply_rope(q_TNH, md.input_positions, H,
+                                   self.rope_theta, self.rope_scaling,
+                                   self.rope_input_ordering)
+            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+        with jax.named_scope("k_proj"):
+            k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                               self.kernel_k_proj_DKH.value)
+            if use_attention_rope:
+                k_SKH = apply_rope(k_SKH, md.input_positions, H,
+                                   self.rope_theta, self.rope_scaling,
+                                   self.rope_input_ordering)
+            k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)
+
+        with jax.named_scope("v_proj"):
+            v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                               self.kernel_v_proj_DKH.value)
+
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = self._q_scale
+            k_scale = self._k_scale
+            v_scale = self._v_scale
+            k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
+                                             self.kv_cache_quantized_dtype,
+                                             k_scale, v_scale)
+
+        with jax.named_scope("attn_op"):
+            new_kv_cache, outputs_TNH = self.attention(
+                is_prefill,
+                kv_cache,
+                q_TNH,
+                k_SKH,
+                v_SKH,
+                attention_metadata,
+                self.mesh,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+
+        with jax.named_scope("o_proj"):
+            o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
+                              self.kernel_o_proj_NHD.value)
+        return new_kv_cache, o_TD
+
+    def attention(
+        self,
+        is_prefill: bool,
+        kv_cache: KVCache,
+        q_TNH: jax.Array,
+        k_SKH: jax.Array,
+        v_SKH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+        q_scale: float | None = None,
+        k_scale: float | None = None,
+        v_scale: float | None = None,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            is_prefill: A boolean indicating if the mode is 'prefill'.
+            kv_cache: The key-value cache to be updated and used.
+            q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+            k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+                - The updated KV cache.
+                - The attention output tensor of shape
+                  `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        kv_cache_spec = P(None, None, "model")
+        in_specs = (
+            self.query_tnh,  # q
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # v
+            kv_cache_spec,  # kv_cache
+            P(),  # md.seq_lens: Replicated
+            P(),  # page_indices_flat: Replicated
+            P(),  # query_start_loc: Replicated
+            P(),  # distribution: Replicated
+        )
+
+        out_specs = (self.attn_o_tnh, kv_cache_spec)
+
+        def _ragged_paged_attention(*args):
+            return ragged_paged_attention(
+                *args,
+                sm_scale=q_TNH.shape[-1]**-0.5,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+
+        output_TNH, kv_cache = jax.jit(
+            shard_map.shard_map(
+                _ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_rep=False,
+            ))(
+                q_TNH,
+                k_SKH,
+                v_SKH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
tpu_inference/layers/jax/attention/deepseek_v3_attention.py
@@ -0,0 +1,354 @@
+import math
+from dataclasses import InitVar, dataclass
+from typing import Any, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import Sharding
+from jax.experimental import shard_map
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference import utils
+from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+    ragged_paged_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.base import create_param
+from tpu_inference.layers.jax.layers import RMSNorm
+from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
+
+KVCache = Tuple[jax.Array, jax.Array]
+
+
+# TODO (wenxindongwork): Add MLA KV cache implementation. For now, cache complete KV vectors.
+@dataclass(kw_only=True)
+class MLA(nnx.Module):
+    """An implementation of Multi-Head Latent Attention as
+    described in the DeepSeek V3 paper.
+
+    Attributes:
+        mesh: The JAX device mesh for distributed computation.
+    """
+    hidden_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    head_dim: int
+    rope_theta: float
+    rope_scaling: dict[str, Any]
+    dtype: jnp.dtype
+    kv_cache_dtype: str
+    mesh: Mesh
+
+    q_lora_rank: int
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+    rms_norm_eps: float
+
+    # Sharding attributes
+    nhd_sharding: Sharding = ()
+    q_da_sharding: Sharding = ()
+    anh_sharding: Sharding = ()
+    kv_da_sharding: Sharding = ()
+
+    activation_attention_td: Sharding = ()
+    activation_q_td: Sharding = ()
+    query_tnh: P = P()
+    keyvalue_skh: P = P()
+
+    attn_o_tnh: P = P()
+    activation_attention_out_td: Sharding = ()
+
+    random_init: bool = False
+    attention_chunk_size: int | None = None
+    rope_input_ordering: str = "split"
+    quant: Any | None = None
+    rope_mscale_all_dim: float = 1.0
+
+    rngs: InitVar[nnx.Rngs]
+
+    _q_scale: float = 1
+    _k_scale: float = 1
+    _v_scale: float = 1
+
+    def __post_init__(self, rngs: nnx.Rngs):
+        self.N = self.num_attention_heads
+        self.K = self.num_key_value_heads
+        self.D = self.hidden_size
+
+        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+
+        assert self.N == self.K, "N and K must be equal for MLA"
+
+        if self.rope_scaling["factor"] <= 1.0:
+            yarn_mscale = 1.0
+        else:
+            yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
+                self.rope_scaling["factor"]) + 1.0
+        self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2
+
+        self.rope = DeepseekScalingRotaryEmbedding(
+            rotary_dim=self.qk_rope_head_dim,
+            rope_theta=self.rope_theta,
+            original_max_position_embeddings=self.
+            rope_scaling["original_max_position_embeddings"],
+            scaling_factor=self.rope_scaling["factor"],
+            dtype=self.dtype,
+            beta_fast=self.rope_scaling["beta_fast"],
+            beta_slow=self.rope_scaling["beta_slow"],
+            mscale_value=self.rope_scaling["mscale"],
+            mscale_all_dim=self.rope_scaling["mscale_all_dim"],
+        )
+
+        # Initializes the weight kernels
+        self.kernel_q_down_proj_DA = create_param(rngs,
+                                                  (self.D, self.q_lora_rank),
+                                                  self.q_da_sharding,
+                                                  self.dtype,
+                                                  random_init=self.random_init)
+        self.kernel_q_up_proj_ANH = create_param(
+            rngs,
+            (self.q_lora_rank, self.N, self.qk_head_dim),
+            self.anh_sharding,
+            self.dtype,
+            random_init=self.random_init,
+        )
+        self.kernel_kv_down_proj_DA = create_param(
+            rngs,
+            (self.D, self.kv_lora_rank + self.qk_rope_head_dim),
+            self.kv_da_sharding,
+            self.dtype,
+            random_init=self.random_init,
+        )
+        self.kernel_kv_up_proj_ANH = create_param(
+            rngs,
+            (self.kv_lora_rank, self.N,
+             self.qk_nope_head_dim + self.v_head_dim),
+            self.anh_sharding,
+            self.dtype,
+            random_init=self.random_init,
+        )
+        self.kernel_o_proj_NHD = create_param(
+            rngs, (self.N, self.v_head_dim, self.D),
+            self.nhd_sharding,
+            self.dtype,
+            random_init=self.random_init)
+        self.q_rms_norm = RMSNorm(
+            dims=self.q_lora_rank,
+            epsilon=self.rms_norm_eps,
+            with_scale=True,
+            dtype=self.dtype,
+            random_init=self.random_init,
+            rngs=rngs,
+        )
+
+        self.kv_rms_norm = RMSNorm(
+            dims=self.kv_lora_rank,
+            random_init=self.random_init,
+            epsilon=self.rms_norm_eps,
+            with_scale=True,
+            dtype=self.dtype,
+            rngs=rngs,
+        )
+
+        self.kv_cache_quantized_dtype = None
+        if self.kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                self.kv_cache_dtype)
+
+    def __call__(self,
+                 x,
+                 is_prefill,
+                 kv_cache: KVCache,
+                 attention_metadata: AttentionMetadata,
+                 use_attention_rope: bool = True):
+        """Performs the forward pass of the attention module.
+
+        Args:
+            x: The input tensor of shape `(batch_size, seq_len, d_model)`.
+            is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+            kv_cache: The key-value cache for storing past attention states.
+            attention_metadata: Metadata for attention, such as input positions.
+
+        Returns:
+            A tuple containing:
+                - The updated KV cache.
+                - The attention output tensor of shape
+                  `(batch_size, seq_len, d_model)`.
+        """
+        md = attention_metadata
+        x = jnp.asarray(x, self.dtype)
+        x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
+        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+
+        with jax.named_scope("q_proj"):
+            # Query down projection.
+            q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
+                              self.kernel_q_down_proj_DA.value)
+            q_TA = self.q_rms_norm(q_TA)
+            # Query up projection.
+            q_TNH = jnp.einsum("TA,ANH -> TNH", q_TA,
+                               self.kernel_q_up_proj_ANH.value)
+            # Split the query into nope and rope.
+            q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
+            q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
+            q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
+            # Concatenate the nope and rope queries.
+            q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+            # Multiple the query by scaling factor
+            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+
+        with jax.named_scope("kv_proj"):
+            # KV down projection.
+            kv_SA = jnp.einsum("SD,DA -> SA", x_SD,
+                               self.kernel_kv_down_proj_DA.value)
+            # Split the key and value into latent kv vector and k rope vector.
+            k_rope_SH = kv_SA[..., self.kv_lora_rank:]
+            # Reshape k_rope_BSH to include head dimension for RoPE application
+            k_rope_SNH = k_rope_SH[..., None, :]
+            k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
+            k_rope_SNH = jnp.broadcast_to(
+                k_rope_SNH,
+                (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+            kv_SA = kv_SA[..., :self.kv_lora_rank]
+            kv_SA = self.kv_rms_norm(kv_SA)
+            # KV up projection.
+            kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
+                                     self.kernel_kv_up_proj_ANH.value)
+            # Split the latent kv vector into k nope vector and v vector.
+            k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+            v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+            # Concatenate the key vector.
+            k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+            k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+            v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
+
+        with jax.named_scope("attn_op"):
+            # TODO(wenxindongwork): K and V have different head dimension,
+            # which is not supported by the current kv cache implementation.
+            # For now we are padding the v dimension to match the k dimension.
+            # Furthermore, deepseekv3 k head dimension is 192, which is
+            # not supported by the current attention kernel, which expects
+            # q, k, v head dimension to be multiple of 128. For now, we will
+            # pad the q, k, v dimension to multiple of 128.
+            # We should update the MLA kv cache implementation in the future.
+            multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
+                                    (0, multiple_of_128 - self.qk_head_dim)))
+            k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
+                                    (0, multiple_of_128 - self.qk_head_dim)))
+            v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
+                                    (0, multiple_of_128 - self.v_head_dim)))
+            q_scale = k_scale = v_scale = None
+            if self.kv_cache_quantized_dtype:
+                # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+                # q_scale = self._q_scale
+                k_scale = self._k_scale
+                v_scale = self._v_scale
+                k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
+                                                 self.kv_cache_quantized_dtype,
+                                                 k_scale, v_scale)
+            new_kv_cache, outputs_TNH = self.attention(
+                is_prefill,
+                kv_cache,
+                q_TNH,
+                k_SNH,
+                v_SNH,
+                attention_metadata,
+                self.mesh,
+                q_scale,
+                k_scale,
+                v_scale,
+            )
+            # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+            # We shall add the MLA kv cache implementation in the future.
+            outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+        with jax.named_scope("o_proj"):
+            o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
+                              self.kernel_o_proj_NHD.value)
+            o_TD = nnx.with_sharding_constraint(
+                o_TD, self.activation_attention_out_td)
+        return new_kv_cache, o_TD
+
+    def attention(
+        self,
+        is_prefill: bool,
+        kv_cache: KVCache,
+        q_TNH: jax.Array,
+        k_SKH: jax.Array,
+        v_SKH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+        q_scale: float | None = None,
+        k_scale: float | None = None,
+        v_scale: float | None = None,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            is_prefill: A boolean indicating if the mode is 'prefill'.
+            kv_cache: The key-value cache to be updated and used.
+            q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+            k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+                - The updated KV cache.
+                - The attention output tensor of shape
+                  `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        in_specs = (
+            self.query_tnh,  # q
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # v
+            P(None, None, "model"),  # kv_cache
+            P(),  # md.seq_lens: Replicated
+            P(),  # page_indices_flat: Replicated
+            P(),  # query_start_loc: Replicated
+            P(),  # distribution: Replicated
+        )
+        out_specs = (self.attn_o_tnh, P(None, None, "model"))
+
+        def _ragged_paged_attention(*args):
+            return ragged_paged_attention(
+                *args,
+                sm_scale=self.scale,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+
+        output_TNH, kv_cache = jax.jit(
+            shard_map.shard_map(
+                _ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_rep=False,
+            ))(
+                q_TNH,
+                k_SKH,
+                v_SKH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
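For DeepSeek-V3-like head sizes, the numbers in __post_init__ and the padding logic work out as follows; the head dims and YaRN rope factor below are assumptions taken from the published DeepSeek-V3 configuration (the diff itself only notes that qk_head_dim is 192), not values read from this package:

    import math

    # Assumed DeepSeek-V3-like configuration (hypothetical values).
    qk_nope_head_dim, qk_rope_head_dim = 128, 64
    rope_factor, rope_mscale_all_dim = 40.0, 1.0

    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim            # 192

    # Same rounding as `multiple_of_128` above: 192 is padded up to 256.
    multiple_of_128 = ((qk_head_dim - 1) // 128 + 1) * 128
    assert multiple_of_128 == 256

    # YaRN mscale correction folded into the softmax scale, as in __post_init__.
    yarn_mscale = 0.1 * rope_mscale_all_dim * math.log(rope_factor) + 1.0
    sm_scale = qk_head_dim**-0.5 * yarn_mscale**2
    print(round(yarn_mscale, 4), round(sm_scale, 4))             # ~1.3689 ~0.1352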