tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/experimental/llama3_jax_stashed.py
@@ -0,0 +1,258 @@
+# TODO: Update documentation
+
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import (Attention,
+                                                          AttentionMetadata)
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                         load_hf_weights)
+
+logger = init_logger(__name__)
+
+
+class LlamaForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: jax.Array,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.rng = nnx.Rngs(rng)
+        self.mesh = mesh
+
+        model_name = self.vllm_config.model_config.model.lower()
+        if "70b" in model_name:
+            logger.info("Initializing Llama3 70B model variant.")
+            self.hidden_size = 8192
+            num_layers = 80
+            self.num_attention_heads = 64
+            self.num_key_value_heads = 8
+            intermediate_size = 28672
+        elif "8b" in model_name:
+            logger.info("Initializing Llama3 8B model variant.")
+            self.hidden_size = 4096
+            num_layers = 32
+            self.num_attention_heads = 32
+            self.num_key_value_heads = 8
+            intermediate_size = 14336
+        else:
+            raise ValueError(
+                f"Could not determine Llama3 variant (8B or 70B) from model name: '{model_name}'. "
+                "Please ensure '8b' or '70b' is in the model path.")
+
+        dtype = jnp.bfloat16
+        self.head_dim = 128
+        rope_theta = 500000.0
+        vocab_size = 128256
+        rms_norm_eps = 1e-5
+
+        self.embedder = Embedder(vocab_size=vocab_size,
+                                 hidden_size=self.hidden_size,
+                                 dtype=dtype,
+                                 rngs=self.rng,
+                                 random_init=force_random_weights,
+                                 vd_sharding=("model", None))
+
+        self.layers = []
+        kv_cache_dtype = self.vllm_config.cache_config.cache_dtype
+        for _ in range(num_layers):
+            self.layers.append(
+                TransformerBlock(
+                    pre_attention_norm=RMSNorm(
+                        dims=self.hidden_size,
+                        random_init=force_random_weights,
+                        epsilon=rms_norm_eps,
+                        rngs=self.rng,
+                        with_scale=True,
+                        dtype=dtype,
+                    ),
+                    pre_mlp_norm=RMSNorm(
+                        dims=self.hidden_size,
+                        rngs=self.rng,
+                        random_init=force_random_weights,
+                        epsilon=rms_norm_eps,
+                        with_scale=True,
+                        dtype=dtype,
+                    ),
+                    attn=Attention(
+                        hidden_size=self.hidden_size,
+                        num_attention_heads=self.num_attention_heads,
+                        num_key_value_heads=self.num_key_value_heads,
+                        head_dim=self.head_dim,
+                        rope_theta=rope_theta,
+                        rope_scaling={},
+                        rngs=self.rng,
+                        dtype=dtype,
+                        # TODO (jacobplatin): we should refactor this to
+                        # pass a dtype (or config) directly
+                        kv_cache_dtype=kv_cache_dtype,
+                        mesh=self.mesh,
+                        random_init=force_random_weights,
+                        dnh_sharding=(None, "model", None),
+                        dkh_sharding=(None, "model", None),
+                        nhd_sharding=("model", None, None),
+                        query_tnh=P(None, "model", None),
+                        keyvalue_skh=P(None, "model", None),
+                        attn_o_tnh=P(None, "model", None),
+                    ),
+                    custom_module=DenseFFW(
+                        dtype=dtype,
+                        hidden_act="silu",
+                        hidden_size=self.hidden_size,
+                        intermediate_size=intermediate_size,
+                        rngs=self.rng,
+                        df_sharding=(None, "model"),
+                        fd_sharding=("model", None),
+                        random_init=force_random_weights),
+                ))
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            rngs=self.rng,
+            random_init=force_random_weights,
+            epsilon=rms_norm_eps,
+            with_scale=True,
+            dtype=dtype,
+        )
+
+        self.lm_head = LMhead(vocab_size=vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=dtype,
+                              rngs=self.rng,
+                              dv_sharding=(None, "model"),
+                              random_init=force_random_weights)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng)
+        weight_loader = Llama3WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+        is_prefill = False
+        with jax.named_scope("llama_embed_input"):  # Embedding
+            x_TD = self.embedder.encode(input_ids)
+
+        with jax.named_scope("llama_model_transformer_blocks"):
+            for i, layer in enumerate(self.layers):
+                kv_cache = kv_caches[i]
+
+                # The first layer is unscoped to avoid JAX tracing issues.
+                # JAX's profiler may incorrectly apply the scope name from the first
+                # layer's kernel compilation to all subsequent layers. Skipping the
+                # first layer ensures distinct scope names for the remaining layers.
+                if i == 0:
+                    new_kv_cache, x_TD = layer(x_TD, is_prefill, kv_cache,
+                                               attention_metadata)
+                else:
+                    with jax.named_scope(f"layer_{i}"):
+                        new_kv_cache, x_TD = layer(x_TD, is_prefill, kv_cache,
+                                                   attention_metadata)
+
+                kv_caches[i] = new_kv_cache
+
+        # Norm after the last transformer block.
+        with jax.named_scope("llama_final_norm"):
+            final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        # LM head projection to produce logits.
+        with jax.named_scope("llama_lm_head_projection"):
+            logits_TV = jnp.dot(hidden_states,
+                                self.lm_head.input_embedding_table_DV.value)
+
+        return logits_TV
+
+
+class Llama3WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self._transpose_map = {
+            "lm_head": (1, 0),
+            "gate_proj": (1, 0),
+            "up_proj": (1, 0),
+            "down_proj": (1, 0),
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+        }
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, -1, hidden_size),
+            "k_proj": (num_key_value_heads, -1, hidden_size),
+            "v_proj": (num_key_value_heads, -1, hidden_size),
+            "o_proj": (hidden_size, attn_heads, -1),
+        }
+        self._bias_shape_map = {
+            "q_proj.bias": (attn_heads, attn_head_dim),
+            "k_proj.bias": (num_key_value_heads, attn_head_dim),
+            "v_proj.bias": (num_key_value_heads, attn_head_dim),
+        }
+
+        # Set the mappings from loaded parameter keys to standardized names.
+        self._loaded_to_standardized_keys = {
+            "model.embed_tokens": "embedder.input_embedding_table_VD",
+            "model.layers.*.input_layernorm": "layers.*.pre_attention_norm.scale",
+            "model.layers.*.mlp.down_proj": "layers.*.custom_module.kernel_down_proj_FD",
+            "model.layers.*.mlp.gate_proj": "layers.*.custom_module.kernel_gating_DF",
+            "model.layers.*.mlp.up_proj": "layers.*.custom_module.kernel_up_proj_DF",
+            "model.layers.*.post_attention_layernorm": "layers.*.pre_mlp_norm.scale",
+            "model.layers.*.self_attn.k_proj": "layers.*.attn.kernel_k_proj_DKH",
+            "model.layers.*.self_attn.o_proj": "layers.*.attn.kernel_o_proj_NHD",
+            "model.layers.*.self_attn.q_proj": "layers.*.attn.kernel_q_proj_DNH",
+            "model.layers.*.self_attn.v_proj": "layers.*.attn.kernel_v_proj_DKH",
+            "model.norm": "final_norm.scale",
+            "lm_head": "lm_head.input_embedding_table_DV",
+        }
+        self.vllm_config = vllm_config
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+        metadata_map = MetadataMap(name_map=self._loaded_to_standardized_keys,
+                                   reshape_map=self._weight_shape_map,
+                                   bias_reshape_map=self._bias_shape_map,
+                                   transpose_map=self._transpose_map)
+        load_hf_weights(vllm_config=self.vllm_config,
+                        model=model_for_loading,
+                        metadata_map=metadata_map,
+                        mesh=model_for_loading.mesh)
+
+        # TODO: validate that all of the model_params were accounted for as well.
+        nnx.update(model_for_loading, model_params)
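
For orientation, here is a minimal construction sketch; it is illustrative only and not part of the diff above. It assumes a JAX TPU runtime, that vLLM's EngineArgs.create_engine_config() is available to build the VllmConfig, and a one-axis device mesh named "model" to match the sharding annotations in the module; the model path is hypothetical and only needs to contain "8b" or "70b" to satisfy the variant check in __init__.

    import jax
    import numpy as np
    from jax.sharding import Mesh
    from vllm.engine.arg_utils import EngineArgs

    # Hypothetical model path; __init__ only inspects it for "8b"/"70b".
    vllm_config = EngineArgs(
        model="meta-llama/Meta-Llama-3-8B").create_engine_config()

    # One-dimensional device mesh whose axis name matches the "model"
    # sharding annotations used throughout the module.
    mesh = Mesh(np.array(jax.devices()), ("model",))

    rng = jax.random.key(0)
    model = LlamaForCausalLM(vllm_config, rng, mesh, force_random_weights=True)
    # With real checkpoints, model.load_weights(rng) would instead pull and
    # remap the HF weights via Llama3WeightLoader / load_hf_weights.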