tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/attention/gpt_oss_attention.py
@@ -0,0 +1,262 @@
+ from dataclasses import InitVar, dataclass
+ from typing import Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.experimental import shard_map
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference import utils
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
+     ragged_paged_attention_hd64
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ @dataclass(kw_only=True)
+ class GptOssAttention(nnx.Module):
+     """
+     JAX implementation of the GPT-OSS Attention block
+     """
+     hidden_size: int
+     num_attention_heads: int
+     num_key_value_heads: int
+     head_dim: int
+     dtype: jnp.dtype
+     rngs: InitVar[nnx.Rngs]
+
+     rope_theta: float
+     initial_context_length: int = 4096
+     rope_scaling_factor: float = 32.0
+     rope_ntk_alpha: float = 1.0
+     rope_ntk_beta: float = 32.0
+     kv_cache_dtype: str
+
+     query_tnh: P = P()
+     keyvalue_skh: P = P()
+     attn_o_tnh: P = P()
+     dnh_sharding: Sharding = ()
+     dkh_sharding: Sharding = ()
+     nhd_sharding: Sharding = ()
+     n_sharding: Sharding = ()
+     nh_sharding: Sharding = ()
+     kh_sharding: Sharding = ()
+     d_sharding: Sharding = ()
+
+     random_init: bool = False
+     mesh: Mesh
+
+     _q_scale: float = 1.0
+     _k_scale: float = 1.0
+     _v_scale: float = 1.0
+     kv_cache_quantized_dtype = None
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Initializes weights, biases, and RoPE module."""
+
+         self.sm_scale = 1.0 / (self.head_dim**0.5)
+
+         self.sinks_N = create_param(
+             rngs,
+             shape=(self.num_attention_heads, ),
+             dtype=jnp.float32,
+             sharding=self.n_sharding,
+             random_init=self.random_init,
+         )
+
+         # Q, K, V projection kernels
+         self.kernel_q_DNH = create_param(
+             rngs,
+             shape=(self.hidden_size, self.num_attention_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.dnh_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_q_NH = create_param(
+             rngs,
+             shape=(self.num_attention_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.nh_sharding,
+             random_init=self.random_init,
+         )
+         self.kernel_k_DKH = create_param(
+             rngs,
+             shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.dkh_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_k_KH = create_param(
+             rngs,
+             shape=(self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.kh_sharding,
+             random_init=self.random_init,
+         )
+         self.kernel_v_DKH = create_param(
+             rngs,
+             shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.dkh_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_v_KH = create_param(
+             rngs,
+             shape=(self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.kh_sharding,
+             random_init=self.random_init,
+         )
+         # Output projection kernel
+         self.kernel_o_proj_NHD = create_param(
+             rngs,
+             shape=(self.num_attention_heads, self.head_dim, self.hidden_size),
+             dtype=self.dtype,
+             sharding=self.nhd_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_o_D = create_param(
+             rngs,
+             shape=(self.hidden_size, ),
+             dtype=self.dtype,
+             sharding=self.d_sharding,
+             random_init=self.random_init,
+         )
+
+         # RoPE Module
+         self.rope = GptOssRotaryEmbedding(
+             head_dim=self.head_dim,
+             rope_theta=self.rope_theta,
+             dtype=self.dtype,
+             initial_context_length=self.initial_context_length,
+             rope_scaling_factor=self.rope_scaling_factor,
+             rope_ntk_alpha=self.rope_ntk_alpha,
+             rope_ntk_beta=self.rope_ntk_beta)
+
+         if self.kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 self.kv_cache_dtype)
+
+     def attention(
+         self,
+         kv_cache: KVCache,
+         q_TNH: jax.Array,
+         k_SKH: jax.Array,
+         v_SKH: jax.Array,
+         sinks: jax.Array,
+         attention_metadata: AttentionMetadata,
+         mesh: Mesh,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ) -> Tuple[KVCache, jax.Array]:
+         """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
+         md = attention_metadata
+         kv_cache_spec = P(None, None, "model")
+
+         in_specs = (
+             self.query_tnh,  # q
+             self.keyvalue_skh,  # k
+             self.keyvalue_skh,  # v
+             kv_cache_spec,  # kv_cache
+             P(),  # md.seq_lens: Replicated
+             P(),  # page_indices_flat: Replicated
+             P(),  # query_start_loc: Replicated
+             P(),  # distribution: Replicated
+             P(('model')),  # sinks
+         )
+         out_specs = (self.attn_o_tnh, kv_cache_spec)
+
+         def _ragged_paged_attention_wrapper(*args):
+             # Pass the GPT-OSS specific parameters to the kernel
+             return ragged_paged_attention_hd64(
+                 *args,
+                 sm_scale=self.sm_scale,
+                 sliding_window=md.sliding_window,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         output_TNH, kv_cache = jax.jit(
+             shard_map.shard_map(
+                 _ragged_paged_attention_wrapper,
+                 mesh=mesh,
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 check_rep=False,
+             ))(
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 kv_cache,
+                 md.seq_lens,
+                 md.block_tables,
+                 md.query_start_loc,
+                 md.request_distribution,
+                 sinks,
+             )
+         return kv_cache, output_TNH
+
+     def __call__(self,
+                  x_TD,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Forward pass for the Attention module using 3D kernels."""
+         md = attention_metadata
+         x_TD = jnp.asarray(x_TD, self.dtype)
+
+         with jax.named_scope("q_proj"):
+             q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, self.kernel_q_DNH.value)
+             q_TNH += self.bias_q_NH.value
+
+         with jax.named_scope("k_proj"):
+             k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_k_DKH.value)
+             k_TKH += self.bias_k_KH.value
+
+         with jax.named_scope("v_proj"):
+             v_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_v_DKH.value)
+             v_TKH += self.bias_v_KH.value
+
+         if use_attention_rope:
+             q_TNH, k_TKH = self.rope(q_TNH, k_TKH, md.input_positions)
+
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k_TKH, v_TKH = utils.quantize_kv(k_TKH, v_TKH,
+                                              self.kv_cache_quantized_dtype,
+                                              k_scale, v_scale)
+
+         with jax.named_scope("attn_op"):
+             new_kv_cache, attn_out_TNH = self.attention(
+                 kv_cache=kv_cache,
+                 q_TNH=q_TNH,
+                 k_SKH=k_TKH,
+                 v_SKH=v_TKH,
+                 sinks=self.sinks_N.value,
+                 attention_metadata=md,
+                 mesh=self.mesh,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+         attn_out_TNH = attn_out_TNH[..., :self.head_dim]
+
+         with jax.named_scope("o_proj"):
+             output_TD = jnp.einsum("TNH,NHD->TD", attn_out_TNH,
+                                    self.kernel_o_proj_NHD.value)
+             output_TD += self.bias_o_D.value
+
+         return new_kv_cache, output_TD
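
The single-letter suffixes in the tensor names above (T = tokens, D = hidden size, N = query heads, K = KV heads, H = head dim) mirror the einsum strings used for the projections. A minimal standalone sketch of those contractions, using toy dimensions chosen here only for illustration (not the GPT-OSS configuration):

    import jax.numpy as jnp

    # Toy dimensions for illustration only.
    T, D, N, K, H = 8, 32, 4, 2, 16

    x_TD = jnp.ones((T, D))
    kernel_q_DNH = jnp.ones((D, N, H))
    kernel_k_DKH = jnp.ones((D, K, H))
    kernel_o_NHD = jnp.ones((N, H, D))

    # Same contractions as GptOssAttention.__call__ above.
    q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, kernel_q_DNH)    # (8, 4, 16)
    k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, kernel_k_DKH)    # (8, 2, 16)
    out_TD = jnp.einsum("TNH,NHD->TD", q_TNH, kernel_o_NHD)  # (8, 32)

    print(q_TNH.shape, k_TKH.shape, out_TD.shape)
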
tpu_inference/layers/jax/attention/llama4_attention.py
@@ -0,0 +1,153 @@
+ from dataclasses import dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Sharding
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention.attention import Attention, KVCache
+ from tpu_inference.layers.jax.rope_interface import apply_rope
+ from tpu_inference.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class L2Norm(nnx.Module):
+     """
+     Implementation of L2 Norm in JAX (taken from the MaxText repo - maxtext/MaxText/layers/attentions.py).
+
+     Attributes:
+         eps: float, epsilon used for numerical stability (the default value should be fine for most cases).
+     """
+
+     def __init__(self, eps: float = 1e-6):
+         self.eps = eps
+
+     def __call__(self, x):
+         return x * jax.lax.rsqrt(
+             jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
+
+
+ @dataclass(kw_only=True)
+ class Llama4Attention(Attention):
+     use_qk_norm: bool
+     temperature_tuning: bool
+     temperature_tuning_floor_scale: float
+     temperature_tuning_scale: float
+     activation_attention_td: Sharding
+     activation_attention_out_td: Sharding
+
+     def __call__(self,
+                  x,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Performs the forward pass of the attention module.
+
+         This method computes the attention output by projecting the input `x`
+         to queries, keys, and values, applying RoPE and L2Norm if specified,
+         performing scaled dot-product attention, and projecting the results
+         back to the model dimension.
+         If no RoPE (NoPE) is specified, one can also perform temperature tuning,
+         which helps combat dilution of attention scores in long-context attention.
+
+         Args:
+             x: The input tensor of shape `(seq_len, d_model)`.
+             is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+             kv_cache: The key-value cache for storing past attention states.
+             attention_metadata: Metadata for attention, such as input positions.
+             use_attention_rope: Whether to use RoPE.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape `(seq_len, d_model)`.
+         """
+         md = attention_metadata
+         x = jnp.asarray(x, self.dtype)
+         x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
+         x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+         rope_scaling = self.rope_scaling
+         rope_theta = self.rope_theta
+         H = self.head_dim
+         l2_norm = L2Norm()
+
+         with jax.named_scope("q_proj"):
+             q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
+                                self.kernel_q_proj_DNH.value)
+             if use_attention_rope:
+                 q_TNH = apply_rope(q_TNH, md.input_positions, H, rope_theta,
+                                    rope_scaling, self.rope_input_ordering)
+
+                 # Apply normalization after RoPE
+                 if self.use_qk_norm:
+                     q_TNH = l2_norm(q_TNH)
+             else:
+                 if self.temperature_tuning:
+                     q_TNH = self.apply_temperature_tuning(md, q_TNH)
+
+             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+         with jax.named_scope("k_proj"):
+             k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_k_proj_DKH.value)
+             if use_attention_rope:
+                 k_SKH = apply_rope(k_SKH, md.input_positions, H, rope_theta,
+                                    rope_scaling, self.rope_input_ordering)
+
+                 # Apply normalization after RoPE
+                 if self.use_qk_norm:
+                     k_SKH = l2_norm(k_SKH)
+             k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)
+
+         with jax.named_scope("v_proj"):
+             v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_v_proj_DKH.value)
+             v_SKH = nnx.with_sharding_constraint(v_SKH, self.keyvalue_skh)
+
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
+                                              self.kv_cache_quantized_dtype,
+                                              k_scale, v_scale)
+
+         with jax.named_scope("attn_op"):
+             new_kv_cache, outputs_TNH = self.attention(
+                 is_prefill,
+                 kv_cache,
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 attention_metadata,
+                 self.mesh,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         with jax.named_scope("o_proj"):
+             o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
+                               self.kernel_o_proj_NHD.value)
+             o_TD = nnx.with_sharding_constraint(
+                 o_TD, self.activation_attention_out_td)
+         return new_kv_cache, o_TD
+
+     def apply_temperature_tuning(self, md: AttentionMetadata,
+                                  input_arr_TNH: jax.Array) -> jax.Array:
+         """Applies temperature tuning to the input array of shape (T, N, H).
+
+         Args:
+             md: AttentionMetadata object containing the input positions.
+             input_arr_TNH: Input array of shape (T, N, H) which will have scaled temperatures applied.
+         """
+         attn_scales = (jnp.log(
+             jnp.floor((md.input_positions.astype(self.dtype) + 1.0) /
+                       self.temperature_tuning_floor_scale) + 1.0) *
+                        self.temperature_tuning_scale + 1.0)
+         return input_arr_TNH * attn_scales[:, None, None]
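
For intuition on apply_temperature_tuning above, here is a standalone sketch of the same formula; the floor_scale and scale values below are made up for illustration and are not Llama 4's actual hyperparameters:

    import jax.numpy as jnp

    # Illustrative values only; the real ones come from the HF config keys
    # temperature_tuning_floor_scale and temperature_tuning_scale.
    floor_scale = 8192.0
    scale = 0.1

    positions = jnp.array([0.0, 10_000.0, 100_000.0])

    # Same formula as Llama4Attention.apply_temperature_tuning: the query scale
    # grows logarithmically with token position and stays 1.0 for short contexts.
    attn_scales = jnp.log(jnp.floor((positions + 1.0) / floor_scale) + 1.0) * scale + 1.0
    print(attn_scales)  # approx [1.0, 1.07, 1.26]
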
tpu_inference/layers/jax/base.py
@@ -0,0 +1,151 @@
+ import dataclasses
+ from dataclasses import dataclass, fields
+ from typing import Any, Callable, Mapping
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference.logger import init_logger
+
+ # Type alias for Initializer for cleaner type hints
+ Initializer = Callable[..., jax.Array]
+ logger = init_logger(__name__)
+
+ # Define singleton initializers to avoid re-compilation.
+ _scale_initializer = nnx.initializers.ones
+ _sharded_initializer = nnx.initializers.xavier_normal()
+ _init_fn = nnx.initializers.uniform()
+
+
+ @dataclass
+ class Config:
+     """Base configuration class with a robust factory method.
+
+     This class provides a `from_cfg` classmethod that allows creating a config
+     instance from a dictionary, ensuring that all required fields are present
+     and ignoring any extraneous keys.
+     """
+
+     @classmethod
+     def from_cfg(cls, cfg: dict[str, Any] | None = None, **kwargs):
+         """Creates a config instance from a dictionary and/or keyword arguments.
+
+         This factory method validates that all fields without default values
+         are provided in the input dictionary or keyword arguments.
+
+         Args:
+             cfg: A dictionary of configuration parameters.
+             **kwargs: Additional configuration parameters passed as keyword arguments.
+
+         Returns:
+             An instance of the configuration class.
+
+         Raises:
+             ValueError: If any required parameters are missing.
+         """
+         if cfg is None:
+             cfg = {}
+         cfg.update(kwargs)
+
+         required_params = {
+             f.name
+             for f in fields(cls) if f.default is dataclasses.MISSING
+             and f.default_factory is dataclasses.MISSING
+         }
+
+         # Check if any of the truly required parameters are missing from the provided config.
+         missing_params = required_params - set(cfg.keys())
+         if missing_params:
+             raise ValueError(
+                 f"Missing required parameters for {cls.__name__}: {', '.join(sorted(list(missing_params)))}"
+             )
+
+         known_params = {f.name for f in fields(cls)}
+         filtered_cfg = {k: v for k, v in cfg.items() if k in known_params}
+
+         return cls(**filtered_cfg)
+
+     # TODO: check logic with some unit tests.
+     def maybe_apply_overrides(self):
+         """Update the args with additional_config, hf_overrides, and override_generation_config settings.
+
+         If there is overlap in overrides between the configs, log a warning declaring which
+         overrides take precedence."""
+
+         if not getattr(self, "vllm_config", None):
+             return
+
+         def _overrides_str(original: str, original_val: Any,
+                            new_val: Any) -> str:
+             return f"{original}: {original_val} ---> {new_val}"
+
+         def _get_overrides_dict(self) -> Mapping[str, Any]:
+             """Return the overrides from all of the possible vLLM sections."""
+             overrides_dict = {}
+             vllm_model_config = self.vllm_config.model_config
+
+             for override_type in ordered_override_types:
+                 if override_type == "additional_config":
+                     overrides_dict[
+                         override_type] = self.vllm_config.additional_config
+                 else:
+                     overrides_dict[override_type] = getattr(
+                         vllm_model_config, override_type)
+             return overrides_dict
+
+         ordered_override_types = [
+             "additional_config", "hf_overrides", "override_generation_config"
+         ]
+
+         overrides_dict = _get_overrides_dict(self)
+
+         # Override the config values using the vLLM sections with highest
+         # precedence first.
+         for field in fields(self):
+             selected_type = None
+             for override_type in reversed(ordered_override_types):
+                 if field.name in overrides_dict[override_type]:
+                     setattr(self, field.name,
+                             overrides_dict[override_type][field.name])
+                     selected_type = override_type
+                     break
+             if selected_type is None:
+                 continue
+
+             # If multiple vLLM sections contain overrides, log a warning.
+             for override_type in ordered_override_types:
+                 if override_type == selected_type:
+                     break
+                 else:
+                     if field.name in overrides_dict[override_type]:
+                         overriden_keys_str = _overrides_str(
+                             field.name,
+                             overrides_dict[override_type][field.name],
+                             overrides_dict[selected_type][field.name])
+                         logger.warning(
+                             f"Overriding {override_type} arguments with the following {selected_type} args: {overriden_keys_str}"
+                         )
+
+     def __post_init__(self):
+         self.maybe_apply_overrides()
+
+
+ def create_param(rngs: nnx.Rngs,
+                  shape: tuple[int, ...],
+                  sharding: Sharding = (),
+                  dtype: Any = jnp.float32,
+                  random_init=False) -> nnx.Param:
+     key = rngs.params()
+     if random_init:
+         initializer = _scale_initializer if len(
+             shape) == 1 else _sharded_initializer
+
+         jitted_initializer = jax.jit(initializer,
+                                      static_argnames=('shape', 'dtype'),
+                                      out_shardings=P(*sharding))
+         param_data = jitted_initializer(key, shape, dtype)
+         return nnx.Param(param_data, sharding=sharding)
+     else:
+         return nnx.Param(_init_fn(key, shape, dtype), sharding=sharding)
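
A short usage sketch of the Config.from_cfg factory and create_param above; ToyConfig and its fields are hypothetical, defined only for this example:

    from dataclasses import dataclass
    from typing import Any

    from flax import nnx

    # Config and create_param come from the base.py hunk above.
    from tpu_inference.layers.jax.base import Config, create_param


    @dataclass
    class ToyConfig(Config):
        # Hypothetical fields, defined only for this example.
        hidden_size: int          # required (no default)
        num_layers: int           # required (no default)
        dtype: str = "bfloat16"   # optional
        vllm_config: Any = None   # no vLLM config, so maybe_apply_overrides is a no-op


    cfg = ToyConfig.from_cfg({"hidden_size": 128, "num_layers": 2, "unused_key": 42})
    print(cfg.hidden_size, cfg.dtype)  # 128 bfloat16 -- unknown keys are dropped

    # Omitting "num_layers" would raise:
    #   ValueError: Missing required parameters for ToyConfig: num_layers

    # create_param draws a key from the rngs "params" stream; with
    # random_init=False it uses the module-level uniform initializer.
    weight = create_param(nnx.Rngs(0), shape=(4, 8))
    print(weight.value.shape)  # (4, 8)
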
tpu_inference/layers/jax/constants.py
@@ -0,0 +1,88 @@
+ """
+ Abbreviations currently used for tensor dimensions:
+     B: Batch size
+     T: Sequence Length (for Query tensors)
+     S: Sequence Length (for Key/Value tensors)
+     D: d_model, the embedding dimension of the model
+     F: d_ff, the hidden dimension of the feed-forward MLP layers
+     V: Vocab Size
+     H: Dimension of each attention head
+     N: Number of query heads in Attention
+     Q: Number of query heads (synonymous with N)
+     K: Number of Key/Value heads in Attention
+     C: Expert capacity in Mixture-of-Experts models
+     X: Number of activated experts per token in MoE
+     G: Number of groups in Grouped-Query Attention
+     E: Total number of experts in MoE
+ """
+
+ import enum
+ from typing import Tuple, TypeAlias
+
+ import jax
+
+ KVCacheType: TypeAlias = Tuple[jax.Array, jax.Array]
+
+
+ class RouterType(enum.Enum):
+     """Enum for router types."""
+     TOP_K = 'top_k'
+
+
+ class OPERATION_MODE(enum.Enum):
+     PREFILL = 1
+     DECODE = 2
+
+
+ class HuggingFaceArgNames(enum.Enum):
+     ## Modeling params
+     HIDDEN_ACT: str = "hidden_act"
+     HIDDEN_SIZE: str = "hidden_size"
+     NUM_HIDDEN_LAYERS: str = "num_hidden_layers"
+     RMS_NORM_EPS: str = "rms_norm_eps"
+     ROPE_SCALING: str = "rope_scaling"
+     ROPE_THETA: str = "rope_theta"
+     VOCAB_SIZE: str = "vocab_size"
+
+     # Block parameters
+     SHARED_EXPERTS: str = "shared_experts"
+
+     # FFW params
+     INTERMEDIATE_SIZE: str = "intermediate_size"
+
+     # Attention params
+     HEAD_DIM: str = "head_dim"
+     NUM_ATTENTION_HEADS: str = "num_attention_heads"
+     NUM_KEY_VALUE_HEADS: str = "num_key_value_heads"
+     ATTENTION_DROPOUT: str = "attention_dropout"
+     ATTENTION_BIAS: str = "attention_bias"
+     ATTENTION_CHUNK_SIZE: str = "attention_chunk_size"
+
+     ## Llama4 Attention Params
+     USE_QK_NORM: str = "use_qk_norm"
+     TEMPERATURE_TUNING: str = "temperature_tuning"
+     TEMPERATURE_TUNING_SCALE: str = "temperature_tuning_scale"
+     TEMPERATURE_TUNING_FLOOR_SCALE: str = "temperature_tuning_floor_scale"
+
+     # MLA params
+     KV_LORA_RANK: str = "kv_lora_rank"
+     Q_LORA_RANK: str = "q_lora_rank"
+     QK_NOPE_HEAD_DIM: str = "qk_nope_head_dim"
+     QK_ROPE_HEAD_DIM: str = "qk_rope_head_dim"
+     V_HEAD_DIM: str = "v_head_dim"
+
+     # MoE
+     INTERMEDIATE_SIZE_MOE: str = "intermediate_size_moe"
+     NUM_LOCAL_EXPERTS: str = "num_local_experts"  # Llama moe
+     NUM_EXPERTS_PER_TOKEN: str = "num_experts_per_token"
+     NUM_ROUTED_EXPERTS: str = "n_routed_experts"  # Deepseek moe
+     NUM_SHARED_ROUTED_EXPERTS: str = "n_shared_experts"
+     NUM_GROUPS: str = "n_group"
+     ROUTED_SCALING_FACTOR: str = "routed_scaling_factor"
+     TOPK_GROUP: str = "topk_group"
+     NORM_TOPK_PROB: str = "norm_topk_prob"
+     SCORING_FUNCTION: str = "scoring_func"
+
+     ## Sampling params
+     BOS_TOKEN_ID: str = "bos_token_id"
+     EOS_TOKEN_ID: str = "eos_token_id"
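
The enum values above are the literal Hugging Face config keys, so they can be used to look up entries in a config dict. A brief sketch with a toy config dict (the values are illustrative):

    from tpu_inference.layers.jax.constants import HuggingFaceArgNames

    # Toy HF-style config dict for illustration.
    hf_config = {"hidden_size": 4096, "num_attention_heads": 32, "rope_theta": 500000.0}

    hidden_size = hf_config[HuggingFaceArgNames.HIDDEN_SIZE.value]
    num_heads = hf_config[HuggingFaceArgNames.NUM_ATTENTION_HEADS.value]
    print(hidden_size, num_heads)  # 4096 32
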