tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference has been flagged as a potentially problematic release.
Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
  49. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py
@@ -0,0 +1,60 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """All-gather matmul kernel's tuned block sizes."""
+
+ import jax
+
+ # key:
+ # - tpu_version
+ # - m
+ # - n
+ # - k
+ # - dtype
+ # - tp_size
+ # value:
+ # - bn
+ # - bk
+ TUNED_BLOCK_SIZES = {
+     # go/keep-sorted start
+     (6, 1024, 51200, 5120, 'bfloat16', 8): (6400, 2560),
+     (6, 1024, 57344, 8192, 'bfloat16', 8): (7168, 8192),
+     (6, 2048, 51200, 5120, 'bfloat16', 8): (1280, 5120),
+     (6, 2048, 57344, 8192, 'bfloat16', 8): (1024, 8192),
+     (6, 4096, 51200, 5120, 'bfloat16', 8): (3200, 5120),
+     (6, 8192, 51200, 5120, 'bfloat16', 8): (1280, 5120),
+     # go/keep-sorted end
+ }
+
+
+ def get_tpu_version() -> int:
+     """Returns the numeric version of the TPU, or -1 if not on TPU."""
+     kind = jax.devices()[0].device_kind
+     if 'TPU' not in kind:
+         return -1
+     if kind.endswith(' lite'):
+         kind = kind[:-len(' lite')]
+     assert kind[:-1] == 'TPU v', kind
+     return int(kind[-1])
+
+
+ def get_key(
+     m,
+     n,
+     k,
+     dtype,
+     tp_size,
+ ):
+     """Returns the key for the given parameters."""
+     return (
+         get_tpu_version(),
+         m,
+         n,
+         k,
+         dtype,
+         tp_size,
+     )
+
+
+ def get_tuned_block_sizes(m, n, k, dtype_name, tp_size):
+     """Returns the tuned block sizes for the given parameters."""
+     key = get_key(m, n, k, dtype_name, tp_size)
+     return TUNED_BLOCK_SIZES.get(key, (None, None))
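For orientation, a lookup against this table can be sketched as below; the import path follows the file list above, the concrete shapes come from one of the tuned entries, and the fallback branch is a hypothetical caller-side choice, not taken from the source. Off TPU, get_tpu_version() returns -1, so the key misses and the function returns (None, None).

    from tpu_inference.kernels.collectives.all_gather_matmul_tuned_block_sizes import (
        get_tuned_block_sizes)

    # On a TPU v6 host with tp_size=8 these shapes hit the table; anywhere else
    # the lookup falls through to the (None, None) default.
    bn, bk = get_tuned_block_sizes(m=2048, n=51200, k=5120,
                                   dtype_name='bfloat16', tp_size=8)
    if bn is None:
        bn, bk = 51200 // 8, 5120  # hypothetical fallback chosen by the caller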
tpu_inference/kernels/collectives/util.py
@@ -0,0 +1,47 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """utilities for collective kernels."""
+
+ import functools
+
+ from jax.experimental import pallas as pl
+ from jax.experimental.pallas import tpu as pltpu
+
+
+ def local_barrier(left_neighbor, right_neighbor, double_barrier=True):
+     """Performs a barrier with neighbors on the global barrier semaphore.
+
+     Optionally performs a second barrier, which prevents a potential race
+     when reusing the same collective_id across kernel invocations.
+
+     Args:
+         left_neighbor: Left neighbor device id.
+         right_neighbor: Right neighbor device id.
+         double_barrier: Whether to perform a second barrier.
+     """
+     barrier_sem = pltpu.get_barrier_semaphore()
+     for neighbor in [left_neighbor, right_neighbor]:
+         pltpu.semaphore_signal(
+             barrier_sem,
+             inc=1,
+             device_id=(neighbor, ),
+             device_id_type=pltpu.DeviceIdType.MESH,
+         )
+     pltpu.semaphore_wait(barrier_sem, 2)
+     if double_barrier:
+         # The double-barrier prevents a race condition where one neighbor can
+         # re-enter the kernel again on a subsequent call and increment the
+         # barrier semaphore a second time. This would unblock the current device
+         # even if the other neighbor is not ready yet.
+         # To implement a double-barrier, we stack-allocate a second REGULAR
+         # semaphore using run_scoped.
+         @functools.partial(pl.run_scoped,
+                            second_barrier=pltpu.SemaphoreType.REGULAR)
+         def _(second_barrier):
+             for neighbor in [left_neighbor, right_neighbor]:
+                 pltpu.semaphore_signal(
+                     second_barrier,
+                     inc=1,
+                     device_id=(neighbor, ),
+                     device_id_type=pltpu.DeviceIdType.MESH,
+                 )
+             pltpu.semaphore_wait(second_barrier, 2)
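For context, the two device ids passed to local_barrier are normally the caller's ring neighbours along the sharded mesh axis; how they are computed is not shown in this file, so the sketch below is an assumption. It only works inside a shard_map/pmap context where the axis name (here 'x') is bound.

    import jax

    def ring_neighbors(axis_name='x'):
        # Neighbour ids on a ring of devices along `axis_name`; the modulo
        # wraps the first and last device around to each other.
        num_devices = jax.lax.psum(1, axis_name)
        my_id = jax.lax.axis_index(axis_name)
        return (my_id - 1) % num_devices, (my_id + 1) % num_devices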
File without changes
File without changes
tpu_inference/layers/common/attention_metadata.py
@@ -0,0 +1,34 @@
+ import functools
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ import jax
+
+
+ @functools.partial(
+     jax.tree_util.register_dataclass,
+     data_fields=[
+         "input_positions",
+         "block_tables",
+         "seq_lens",
+         "query_start_loc",
+         "request_distribution",
+     ],
+     meta_fields=[],
+     drop_fields=["query_start_loc_cpu", "seq_lens_cpu"],
+ )
+ @dataclass
+ class AttentionMetadata(object):
+     # (padded_total_num_scheduled_tokens,)
+     input_positions: jax.Array
+     # (max_num_seqs * max_num_blocks_per_req,)
+     block_tables: jax.Array = None
+     # (max_num_seqs,)
+     seq_lens: jax.Array = None
+     # (max_num_seqs + 1,)
+     query_start_loc: jax.Array = None
+     # (3,)
+     request_distribution: jax.Array = None
+
+     query_start_loc_cpu: Any = field(init=False)
+     seq_lens_cpu: Any = field(init=False)
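Because the class is registered as a pytree with the five device arrays as data fields and the two *_cpu fields dropped, an instance can cross jit/shard_map boundaries with only the arrays being traced. A minimal sketch of that behaviour (toy values, purely illustrative):

    import jax
    import jax.numpy as jnp

    from tpu_inference.layers.common.attention_metadata import AttentionMetadata

    md = AttentionMetadata(
        input_positions=jnp.arange(8, dtype=jnp.int32),
        block_tables=jnp.zeros((4,), dtype=jnp.int32),
        seq_lens=jnp.array([2, 3, 0, 0], dtype=jnp.int32),
        query_start_loc=jnp.array([0, 2, 5, 5, 5], dtype=jnp.int32),
        request_distribution=jnp.array([0, 0, 2], dtype=jnp.int32),
    )
    md.query_start_loc_cpu = [0, 2, 5, 5, 5]  # host-side copy, not a pytree leaf
    assert len(jax.tree_util.tree_leaves(md)) == 5  # only the device arrays flatten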
File without changes
File without changes
tpu_inference/layers/jax/attention/attention.py
@@ -0,0 +1,254 @@
+ from dataclasses import InitVar, dataclass
+ from typing import Any, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.experimental import shard_map
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference import utils
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+     ragged_paged_attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.rope_interface import apply_rope
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ @dataclass(kw_only=True)
+ class Attention(nnx.Module):
+     """An implementation of attention.
+
+     This module performs the attention mechanism for a transformer model,
+     including query, key, and value projections, application of Rotary
+     Position Embeddings (RoPE), and management of a KV cache for efficient
+     autoregressive generation. It supports both prefill and generation
+     (decode) modes and handles tensor sharding for distributed computation.
+
+     Attributes:
+         mesh: The JAX device mesh for distributed computation.
+     """
+     hidden_size: int
+     num_attention_heads: int
+     num_key_value_heads: int
+     head_dim: int
+     rope_theta: float
+     rope_scaling: dict[str, Any]
+     dtype: jnp.dtype
+     mesh: Mesh
+     kv_cache_dtype: str
+
+     dnh_sharding: Sharding = ()
+     dkh_sharding: Sharding = ()
+     nhd_sharding: Sharding = ()
+
+     activation_q_td: Sharding = ()
+     query_tnh: P = P()
+     keyvalue_skh: P = P()
+
+     attn_o_tnh: P = P()
+     rngs: InitVar[nnx.Rngs]
+
+     random_init: bool = False
+     attention_chunk_size: int | None = None
+     rope_input_ordering: str = "split"
+
+     _q_scale: float = 1.0
+     _k_scale: float = 1.0
+     _v_scale: float = 1.0
+
+     kv_cache_quantized_dtype = None
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Initializes the weight kernels for Q, K, V, and O projections."""
+         N = self.num_attention_heads
+         K = self.num_key_value_heads
+         D = self.hidden_size
+         H = self.head_dim
+
+         self.kernel_q_proj_DNH = create_param(rngs, (D, N, H),
+                                               self.dnh_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+         self.kernel_k_proj_DKH = create_param(rngs, (D, K, H),
+                                               self.dkh_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+         self.kernel_v_proj_DKH = create_param(rngs, (D, K, H),
+                                               self.dkh_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+         self.kernel_o_proj_NHD = create_param(rngs, (N, H, D),
+                                               self.nhd_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+
+         if self.kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 self.kv_cache_dtype)
+
+     def __call__(self,
+                  x,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Performs the forward pass of the attention module.
+
+         This method computes the attention output by projecting the input `x`
+         to queries, keys, and values, applying RoPE, performing scaled
+         dot-product attention, and projecting the result back to the model
+         dimension. It updates and utilizes a KV cache.
+
+         Args:
+             x: The input tensor of shape `(seq_len, d_model)`.
+             is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+             kv_cache: The key-value cache for storing past attention states.
+             attention_metadata: Metadata for attention, such as input positions.
+             use_attention_rope: Whether to use RoPE.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape
+                   `(batch_size, seq_len, d_model)`.
+         """
+         md = attention_metadata
+         x_SD = jnp.asarray(x, self.dtype)
+         x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+         H = self.head_dim
+         with jax.named_scope("q_proj"):
+             q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
+                                self.kernel_q_proj_DNH.value)
+             if use_attention_rope:
+                 q_TNH = apply_rope(q_TNH, md.input_positions, H,
+                                    self.rope_theta, self.rope_scaling,
+                                    self.rope_input_ordering)
+             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+         with jax.named_scope("k_proj"):
+             k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_k_proj_DKH.value)
+             if use_attention_rope:
+                 k_SKH = apply_rope(k_SKH, md.input_positions, H,
+                                    self.rope_theta, self.rope_scaling,
+                                    self.rope_input_ordering)
+             k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)
+
+         with jax.named_scope("v_proj"):
+             v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_v_proj_DKH.value)
+
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
+                                              self.kv_cache_quantized_dtype,
+                                              k_scale, v_scale)
+
+         with jax.named_scope("attn_op"):
+             new_kv_cache, outputs_TNH = self.attention(
+                 is_prefill,
+                 kv_cache,
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 attention_metadata,
+                 self.mesh,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         with jax.named_scope("o_proj"):
+             o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
+                               self.kernel_o_proj_NHD.value)
+         return new_kv_cache, o_TD
+
+     def attention(
+         self,
+         is_prefill: bool,
+         kv_cache: KVCache,
+         q_TNH: jax.Array,
+         k_SKH: jax.Array,
+         v_SKH: jax.Array,
+         attention_metadata: AttentionMetadata,
+         mesh: Mesh,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ) -> Tuple[KVCache, jax.Array]:
+         """Performs scaled dot-product attention and updates the KV cache.
+
+         This function handles the core attention logic, which varies between
+         prefill and generation modes. In prefill, it computes self-attention
+         over the input sequence with a causal mask. In generation, it attends
+         to the full history of keys and values stored in the cache.
+
+         Args:
+             is_prefill: A boolean indicating if the mode is 'prefill'.
+             kv_cache: The key-value cache to be updated and used.
+             q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+             k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+             v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+             attention_metadata: Metadata containing sequence lengths.
+             mesh: The JAX device mesh (unused in this specific function but
+                 kept for potential future use or API consistency).
+             q_scale: Quantization scale for q.
+             k_scale: Quantization scale for k.
+             v_scale: Quantization scale for v.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape
+                   `(seq, num_q_heads, head_dim)`.
+         """
+         md = attention_metadata
+         kv_cache_spec = P(None, None, "model")
+         in_specs = (
+             self.query_tnh,  # q
+             self.keyvalue_skh,  # k
+             self.keyvalue_skh,  # v
+             kv_cache_spec,  # kv_cache
+             P(),  # md.seq_lens: Replicated
+             P(),  # page_indices_flat: Replicated
+             P(),  # query_start_loc: Replicated
+             P(),  # distribution: Replicated
+         )
+
+         out_specs = (self.attn_o_tnh, kv_cache_spec)
+
+         def _ragged_paged_attention(*args):
+             return ragged_paged_attention(
+                 *args,
+                 sm_scale=q_TNH.shape[-1]**-0.5,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         output_TNH, kv_cache = jax.jit(
+             shard_map.shard_map(
+                 _ragged_paged_attention,
+                 mesh=mesh,
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 check_rep=False,
+             ))(
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 kv_cache,
+                 md.seq_lens,
+                 md.block_tables,
+                 md.query_start_loc,
+                 md.request_distribution,
+             )
+         return kv_cache, output_TNH
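Throughout this module the shape suffixes name the axes (T: query tokens, S: key/value tokens, D: model dim, N: query heads, K: KV heads, H: head dim), and the projections are plain einsums over those axes. A small standalone sketch with toy sizes (not tied to any real model) shows the Q and output projections and the sm_scale expression passed to the kernel:

    import jax.numpy as jnp

    T, D, N, H = 16, 64, 8, 32  # tokens, hidden size, query heads, head dim
    x_TD = jnp.ones((T, D), jnp.bfloat16)
    w_q_DNH = jnp.ones((D, N, H), jnp.bfloat16)
    w_o_NHD = jnp.ones((N, H, D), jnp.bfloat16)

    q_TNH = jnp.einsum('TD,DNH -> TNH', x_TD, w_q_DNH)  # -> (16, 8, 32)
    o_TD = jnp.einsum('TNH,NHD -> TD', q_TNH, w_o_NHD)  # -> (16, 64)
    sm_scale = q_TNH.shape[-1]**-0.5                    # 1 / sqrt(head_dim)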