tpu_inference-0.11.1rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (123)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  53. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  54. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  55. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  56. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  57. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  58. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  59. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  60. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  61. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  63. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  65. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  67. tpu_inference/logger.py +10 -0
  68. tpu_inference/lora/__init__.py +0 -0
  69. tpu_inference/lora/torch_lora_ops.py +103 -0
  70. tpu_inference/lora/torch_punica_tpu.py +308 -0
  71. tpu_inference/mock/__init__.py +0 -0
  72. tpu_inference/mock/vllm_config_utils.py +28 -0
  73. tpu_inference/mock/vllm_envs.py +1233 -0
  74. tpu_inference/mock/vllm_logger.py +212 -0
  75. tpu_inference/mock/vllm_logging_utils.py +15 -0
  76. tpu_inference/models/__init__.py +0 -0
  77. tpu_inference/models/jax/__init__.py +0 -0
  78. tpu_inference/models/jax/deepseek_v3.py +868 -0
  79. tpu_inference/models/jax/llama3.py +366 -0
  80. tpu_inference/models/jax/llama4.py +473 -0
  81. tpu_inference/models/jax/llama_eagle3.py +333 -0
  82. tpu_inference/models/jax/phi3.py +376 -0
  83. tpu_inference/models/jax/qwen2.py +375 -0
  84. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  85. tpu_inference/models/jax/qwen3.py +302 -0
  86. tpu_inference/models/jax/utils/__init__.py +0 -0
  87. tpu_inference/models/jax/utils/file_utils.py +96 -0
  88. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  89. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  90. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  91. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  92. tpu_inference/models/vllm/__init__.py +0 -0
  93. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  94. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  95. tpu_inference/platforms/__init__.py +2 -0
  96. tpu_inference/platforms/tpu_jax.py +257 -0
  97. tpu_inference/runner/__init__.py +0 -0
  98. tpu_inference/runner/block_table_jax.py +122 -0
  99. tpu_inference/runner/compilation_manager.py +672 -0
  100. tpu_inference/runner/input_batch_jax.py +435 -0
  101. tpu_inference/runner/kv_cache.py +119 -0
  102. tpu_inference/runner/kv_cache_manager.py +460 -0
  103. tpu_inference/runner/lora_utils.py +92 -0
  104. tpu_inference/runner/multimodal_manager.py +208 -0
  105. tpu_inference/runner/persistent_batch_manager.py +244 -0
  106. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  107. tpu_inference/runner/structured_decoding_manager.py +89 -0
  108. tpu_inference/runner/tpu_jax_runner.py +771 -0
  109. tpu_inference/runner/utils.py +426 -0
  110. tpu_inference/spec_decode/__init__.py +0 -0
  111. tpu_inference/spec_decode/jax/__init__.py +0 -0
  112. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  113. tpu_inference/tpu_info.py +77 -0
  114. tpu_inference/utils.py +294 -0
  115. tpu_inference/worker/__init__.py +0 -0
  116. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  117. tpu_inference/worker/base.py +100 -0
  118. tpu_inference/worker/tpu_worker_jax.py +321 -0
  119. tpu_inference-0.11.1rc1.dist-info/METADATA +101 -0
  120. tpu_inference-0.11.1rc1.dist-info/RECORD +123 -0
  121. tpu_inference-0.11.1rc1.dist-info/WHEEL +5 -0
  122. tpu_inference-0.11.1rc1.dist-info/licenses/LICENSE +201 -0
  123. tpu_inference-0.11.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen2.py
@@ -0,0 +1,375 @@
+ from typing import List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import Qwen2Config, modeling_flax_utils
+ from vllm.config import VllmConfig
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import attention
+ from tpu_inference.layers.jax.rope_interface import apply_rope
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
+                                                          load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+
+ class Qwen2MLP(nnx.Module):
+
+     def __init__(self, config: Qwen2Config, dtype: jnp.dtype, rng: nnx.Rngs):
+         hidden_size = config.hidden_size
+         intermediate_size = config.intermediate_size
+         act = config.hidden_act
+
+         self.gate_proj = nnx.Linear(
+             hidden_size,
+             intermediate_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=rng,
+         )
+         self.up_proj = nnx.Linear(
+             hidden_size,
+             intermediate_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=rng,
+         )
+         self.down_proj = nnx.Linear(
+             intermediate_size,
+             hidden_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.act_fn = modeling_flax_utils.ACT2FN[act]
+
+     def __call__(self, x: jax.Array) -> jax.Array:
+         gate = self.act_fn(self.gate_proj(x))
+         up = self.up_proj(x)
+         fuse = gate * up
+         result = self.down_proj(fuse)
+         return result
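
Note: Qwen2MLP.__call__ above is the standard gated MLP used by Qwen2-family models. As a minimal, self-contained sketch of the same dataflow written with plain matmuls (assuming the configured hidden_act is "silu", the usual setting for Qwen2; the weight arrays here are hypothetical stand-ins for the nnx.Linear modules above):

import jax
import jax.numpy as jnp

def gated_mlp(x: jax.Array, w_gate: jax.Array, w_up: jax.Array,
              w_down: jax.Array) -> jax.Array:
    # Same dataflow as Qwen2MLP.__call__: down(act(gate(x)) * up(x)).
    gate = jax.nn.silu(x @ w_gate)  # (T, D) @ (D, I) -> (T, I)
    up = x @ w_up                   # (T, D) @ (D, I) -> (T, I)
    return (gate * up) @ w_down     # (T, I) @ (I, D) -> (T, D)
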
+
+
+ class Qwen2Attention(nnx.Module):
+
+     def __init__(self, config: Qwen2Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.rope_theta = config.rope_theta
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+
+         self.head_dim_original = getattr(config, "head_dim",
+                                          self.hidden_size // self.num_heads)
+         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+         sharding_size = mesh.shape["model"]
+         self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                     sharding_size)
+         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                        sharding_size)
+
+         self.mesh = mesh
+
+         self.q_proj = nnx.Einsum(
+             "TD,DNH->TNH",
+             (self.hidden_size, self.num_heads, self.head_dim),
+             (self.num_heads, self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             bias_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.k_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (self.hidden_size, self.num_kv_heads, self.head_dim),
+             (self.num_kv_heads, self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             bias_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.v_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (self.hidden_size, self.num_kv_heads, self.head_dim),
+             (self.num_kv_heads, self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             bias_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.o_proj = nnx.Einsum(
+             "TNH,NHD->TD",
+             (self.num_heads, self.head_dim, self.hidden_size),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+             rngs=rng,
+         )
+
+         self._q_scale = 1.0
+         self._k_scale = 1.0
+         self._v_scale = 1.0
+         self.kv_cache_quantized_dtype = None
+         if kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 kv_cache_dtype)
+
+     def __call__(
+         self,
+         kv_cache: Optional[jax.Array],
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         md = attention_metadata
+         # q: (T, N, H)
+         q = self.q_proj(x)
+         q = apply_rope(q, md.input_positions, self.head_dim_original,
+                        self.rope_theta, self.rope_scaling)
+
+         # k: (T, K, H)
+         k = self.k_proj(x)
+         k = apply_rope(k, md.input_positions, self.head_dim_original,
+                        self.rope_theta, self.rope_scaling)
+
+         # v: (T, K, H)
+         v = self.v_proj(x)
+         # o: (T, N, H)
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                      k_scale, v_scale)
+         new_kv_cache, outputs = attention(
+             kv_cache,
+             q,
+             k,
+             v,
+             attention_metadata,
+             self.mesh,
+             self.head_dim_original,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+         # (T, D)
+         o = self.o_proj(outputs)
+         return new_kv_cache, o
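
Note: when kv_cache_dtype is not "auto", the block above quantizes K and V through utils.quantize_kv before passing them to attention(); that helper's implementation is not part of this file. As a rough, generic illustration only (not the package's code), symmetric per-tensor quantization to int8 with a fixed scale amounts to:

import jax
import jax.numpy as jnp

def quantize_int8(x: jax.Array, scale: float) -> jax.Array:
    # Scale, round, and clip into the int8 range.
    return jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)

def dequantize_int8(q: jax.Array, scale: float, dtype=jnp.bfloat16) -> jax.Array:
    # Approximate inverse, applied when the cached values are consumed.
    return q.astype(dtype) * scale
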
+
+
+ class Qwen2DecoderLayer(nnx.Module):
+
+     def __init__(self, config: Qwen2Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         rms_norm_eps = config.rms_norm_eps
+         hidden_size = config.hidden_size
+
+         self.input_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.self_attn = Qwen2Attention(config=config,
+                                         dtype=dtype,
+                                         rng=rng,
+                                         mesh=mesh,
+                                         kv_cache_dtype=kv_cache_dtype)
+         self.post_attention_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.mlp = Qwen2MLP(
+             config=config,
+             dtype=dtype,
+             rng=rng,
+         )
+
+     def __call__(
+         self,
+         kv_cache: jax.Array,
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         hidden_states = self.input_layernorm(x)
+         kv_cache, attn_output = self.self_attn(
+             kv_cache,
+             hidden_states,
+             attention_metadata,
+         )
+         attn_output += x
+
+         residual = attn_output
+         attn_output = self.post_attention_layernorm(attn_output)
+         outputs = self.mlp(attn_output)
+         outputs = residual + outputs
+         return kv_cache, outputs
+
+
+ class Qwen2Model(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                  mesh: Mesh) -> None:
+         model_config = vllm_config.model_config
+         hf_config = model_config.hf_config
+         vocab_size = model_config.get_vocab_size()
+         dtype = model_config.dtype
+         rms_norm_eps = hf_config.rms_norm_eps
+         hidden_size = hf_config.hidden_size
+
+         self.embed = nnx.Embed(
+             num_embeddings=vocab_size,
+             features=hidden_size,
+             param_dtype=dtype,
+             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.layers = [
+             Qwen2DecoderLayer(
+                 config=hf_config,
+                 dtype=dtype,
+                 rng=rng,
+                 mesh=mesh,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+             for _ in range(hf_config.num_hidden_layers)
+         ]
+         self.norm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         if model_config.hf_config.tie_word_embeddings:
+             self.lm_head = self.embed.embedding
+         else:
+             self.lm_head = nnx.Param(
+                 init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                 sharding=(None, "model"),
+             )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: Optional[jax.Array],
+         attention_metadata: AttentionMetadata,
+         inputs_embeds: Optional[jax.Array] = None,
+     ) -> Tuple[List[jax.Array], jax.Array]:
+         if inputs_embeds is not None:
+             x = inputs_embeds
+         else:
+             x = self.embed(input_ids)
+         for i, layer in enumerate(self.layers):
+             kv_cache = kv_caches[i]
+             kv_cache, x = layer(
+                 kv_cache,
+                 x,
+                 attention_metadata,
+             )
+             kv_caches[i] = kv_cache
+         x = self.norm(x)
+         return kv_caches, x
+
+
+ class Qwen2ForCausalLM(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh) -> None:
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+
+         self.model = Qwen2Model(
+             vllm_config=vllm_config,
+             rng=self.rng,
+             mesh=mesh,
+         )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: Optional[jax.Array],
+         attention_metadata: AttentionMetadata,
+         inputs_embeds: Optional[jax.Array] = None,
+         *args,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         kv_caches, x = self.model(
+             kv_caches,
+             input_ids,
+             attention_metadata,
+             inputs_embeds,
+         )
+         return kv_caches, x, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+         else:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value)
+         return logits
+
+     def load_weights(self, rng_key: jax.Array):
+         # NOTE: Since we are using nnx.eval_shape to init the model,
+         # we have to pass dynamic arrays here for __call__'s usage.
+         self.rng = nnx.Rngs(rng_key)
+
+         # Key: path to a HF layer weight
+         # Value: path to a nnx layer weight
+         mappings = {
+             "model.embed_tokens": "model.embed.embedding",
+             "model.layers.*.input_layernorm":
+             "model.layers.*.input_layernorm.scale",
+             "model.layers.*.mlp.down_proj":
+             "model.layers.*.mlp.down_proj.kernel",
+             "model.layers.*.mlp.gate_proj":
+             "model.layers.*.mlp.gate_proj.kernel",
+             "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
+             "model.layers.*.post_attention_layernorm":
+             "model.layers.*.post_attention_layernorm.scale",
+             "model.layers.*.self_attn.k_proj":
+             "model.layers.*.self_attn.k_proj.kernel",
+             "model.layers.*.self_attn.o_proj":
+             "model.layers.*.self_attn.o_proj.kernel",
+             "model.layers.*.self_attn.q_proj":
+             "model.layers.*.self_attn.q_proj.kernel",
+             "model.layers.*.self_attn.v_proj":
+             "model.layers.*.self_attn.v_proj.kernel",
+             "model.layers.*.self_attn.q_proj.bias":
+             "model.layers.*.self_attn.q_proj.bias",
+             "model.layers.*.self_attn.k_proj.bias":
+             "model.layers.*.self_attn.k_proj.bias",
+             "model.layers.*.self_attn.v_proj.bias":
+             "model.layers.*.self_attn.v_proj.bias",
+             "model.norm": "model.norm.scale",
+         }
+
+         # Add lm_head mapping only if it's not tied to embeddings
+         if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             mappings.update({
+                 "lm_head": "model.lm_head",
+             })
+
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+         load_hf_weights(vllm_config=self.vllm_config,
+                         model=self,
+                         metadata_map=metadata_map,
+                         mesh=self.mesh)
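
Note: in load_weights above, the "*" in each mapping path is a per-layer wildcard; expanding it into concrete layer indices happens inside get_default_maps / load_hf_weights, which are not part of this file. A hypothetical sketch of what that expansion amounts to (the helper name and behavior here are illustrative, not tpu_inference APIs):

def expand_layer_wildcards(mappings: dict, num_layers: int) -> dict:
    # Replace the "*" layer wildcard in both the HF path (key) and the
    # nnx path (value) with each concrete layer index.
    expanded = {}
    for hf_path, nnx_path in mappings.items():
        if "*" in hf_path:
            for i in range(num_layers):
                expanded[hf_path.replace("*", str(i))] = nnx_path.replace(
                    "*", str(i))
        else:
            expanded[hf_path] = nnx_path
    return expanded

# e.g. "model.layers.*.mlp.up_proj" expands to "model.layers.0.mlp.up_proj",
# "model.layers.1.mlp.up_proj", ..., one entry per hidden layer.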