tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen3.py
@@ -0,0 +1,302 @@
+ from typing import List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import Qwen3Config
+ from vllm.config import VllmConfig
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_interface import attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.rope_interface import apply_rope
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
+ from tpu_inference.models.jax.qwen2 import Qwen2MLP as Qwen3MLP
+ from tpu_inference.models.jax.qwen2 import Qwen2Model
+ from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
+                                                          load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+
+ class Qwen3Attention(nnx.Module):
+
+     def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.rope_theta = config.rope_theta
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.rms_norm_eps = config.rms_norm_eps
+
+         self.head_dim_original = getattr(config, "head_dim",
+                                          self.hidden_size // self.num_heads)
+         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+         sharding_size = mesh.shape["model"]
+         self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                     sharding_size)
+         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                        sharding_size)
+
+         self.mesh = mesh
+
+         self.q_proj = nnx.Einsum(
+             "TD,DNH->TNH",
+             (self.hidden_size, self.num_heads, self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.q_norm = nnx.RMSNorm(
+             self.head_dim,
+             epsilon=self.rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.k_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (self.hidden_size, self.num_kv_heads, self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.k_norm = nnx.RMSNorm(
+             self.head_dim,
+             epsilon=self.rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.v_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (self.hidden_size, self.num_kv_heads, self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.o_proj = nnx.Einsum(
+             "TNH,NHD->TD",
+             (self.num_heads, self.head_dim, self.hidden_size),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+             rngs=rng,
+         )
+
+         self._q_scale = 1.0
+         self._k_scale = 1.0
+         self._v_scale = 1.0
+         self.kv_cache_quantized_dtype = None
+         if kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 kv_cache_dtype)
+
+     def __call__(
+         self,
+         kv_cache: Optional[jax.Array],
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         md = attention_metadata
+         # q: (T, N, H)
+         q = self.q_proj(x)
+         q = self.q_norm(q)
+         q = apply_rope(q, md.input_positions, self.head_dim_original,
+                        self.rope_theta, self.rope_scaling)
+
+         # k: (T, K, H)
+         k = self.k_proj(x)
+         k = self.k_norm(k)
+         k = apply_rope(k, md.input_positions, self.head_dim_original,
+                        self.rope_theta, self.rope_scaling)
+
+         # v: (T, K, H)
+         v = self.v_proj(x)
+         # o: (T, N, H)
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                      k_scale, v_scale)
+         new_kv_cache, outputs = attention(
+             kv_cache,
+             q,
+             k,
+             v,
+             attention_metadata,
+             self.mesh,
+             self.head_dim_original,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+         # (T, D)
+         o = self.o_proj(outputs)
+         return new_kv_cache, o
+
+
+ class Qwen3DecoderLayer(Qwen2DecoderLayer):
+
+     def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         rms_norm_eps = config.rms_norm_eps
+         hidden_size = config.hidden_size
+
+         self.input_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.self_attn = Qwen3Attention(config=config,
+                                         dtype=dtype,
+                                         rng=rng,
+                                         mesh=mesh,
+                                         kv_cache_dtype=kv_cache_dtype)
+         self.post_attention_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.mlp = Qwen3MLP(
+             config=config,
+             dtype=dtype,
+             rng=rng,
+         )
+
+
+ class Qwen3Model(Qwen2Model):
+
+     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                  mesh: Mesh) -> None:
+         model_config = vllm_config.model_config
+         hf_config = model_config.hf_config
+         vocab_size = model_config.get_vocab_size()
+         dtype = model_config.dtype
+         rms_norm_eps = hf_config.rms_norm_eps
+         hidden_size = hf_config.hidden_size
+
+         self.embed = nnx.Embed(
+             num_embeddings=vocab_size,
+             features=hidden_size,
+             param_dtype=dtype,
+             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.layers = [
+             Qwen3DecoderLayer(
+                 config=hf_config,
+                 dtype=dtype,
+                 rng=rng,
+                 mesh=mesh,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+             for _ in range(hf_config.num_hidden_layers)
+         ]
+         self.norm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         if model_config.hf_config.tie_word_embeddings:
+             self.lm_head = self.embed.embedding
+         else:
+             self.lm_head = nnx.Param(
+                 init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                 sharding=(None, "model"),
+             )
+
+
+ class Qwen3ForCausalLM(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh) -> None:
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+
+         self.model = Qwen3Model(
+             vllm_config=vllm_config,
+             rng=self.rng,
+             mesh=mesh,
+         )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         *args,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         kv_caches, x = self.model(
+             kv_caches,
+             input_ids,
+             attention_metadata,
+         )
+         return kv_caches, x, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+         else:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value)
+         return logits
+
+     def load_weights(self, rng_key: jax.Array):
+         # NOTE: Since we are using nnx.eval_shape to init the model,
+         # we have to pass dynamic arrays here for __call__'s usage.
+         self.rng = nnx.Rngs(rng_key)
+
+         # Key: path to a HF layer weight
+         # Value: path to an nnx layer weight
+         mappings = {
+             "model.embed_tokens": "model.embed.embedding",
+             "model.layers.*.input_layernorm":
+             "model.layers.*.input_layernorm.scale",
+             "model.layers.*.mlp.down_proj":
+             "model.layers.*.mlp.down_proj.kernel",
+             "model.layers.*.mlp.gate_proj":
+             "model.layers.*.mlp.gate_proj.kernel",
+             "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
+             "model.layers.*.post_attention_layernorm":
+             "model.layers.*.post_attention_layernorm.scale",
+             "model.layers.*.self_attn.k_norm":
+             "model.layers.*.self_attn.k_norm.scale",
+             "model.layers.*.self_attn.k_proj":
+             "model.layers.*.self_attn.k_proj.kernel",
+             "model.layers.*.self_attn.o_proj":
+             "model.layers.*.self_attn.o_proj.kernel",
+             "model.layers.*.self_attn.q_norm":
+             "model.layers.*.self_attn.q_norm.scale",
+             "model.layers.*.self_attn.q_proj":
+             "model.layers.*.self_attn.q_proj.kernel",
+             "model.layers.*.self_attn.v_proj":
+             "model.layers.*.self_attn.v_proj.kernel",
+             "model.norm": "model.norm.scale",
+         }
+
+         # Add the lm_head mapping only if it is not tied to the embeddings.
+         if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             mappings.update({
+                 "lm_head": "model.lm_head",
+             })
+
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+         load_hf_weights(vllm_config=self.vllm_config,
+                         model=self,
+                         metadata_map=metadata_map,
+                         mesh=self.mesh)
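The einsum layout above keeps the head dimension explicit: q_proj maps tokens of shape (T, D) to (T, N, H), and o_proj maps back, with the head count and head dim padded (via get_padded_num_heads/get_padded_head_dim) so they divide evenly across the "model" mesh axis. A minimal sketch of that shape contract with plain jnp.einsum, using toy dimensions of my own choosing rather than real Qwen3 config values:

import jax
import jax.numpy as jnp

# Toy dimensions (hypothetical, not from any real Qwen3 config):
# T tokens, D hidden size, N query heads, H (padded) head dim.
T, D, N, H = 4, 32, 8, 16

kx, kq, ko = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (T, D))
w_q = jax.random.normal(kq, (D, N, H))  # same layout as q_proj's kernel
w_o = jax.random.normal(ko, (N, H, D))  # same layout as o_proj's kernel

q = jnp.einsum("TD,DNH->TNH", x, w_q)   # per-head query projection
o = jnp.einsum("TNH,NHD->TD", q, w_o)   # merge heads back into the hidden size

assert q.shape == (T, N, H) and o.shape == (T, D)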
tpu_inference/models/jax/utils/__init__.py: File without changes
tpu_inference/models/jax/utils/file_utils.py
@@ -0,0 +1,96 @@
+ import glob
+ import hashlib
+ import os
+ import shutil
+ import subprocess
+ from typing import List, Optional
+
+ import filelock
+ import huggingface_hub.constants
+ from huggingface_hub import HfFileSystem, snapshot_download
+ from tqdm.auto import tqdm
+
+ from tpu_inference.logger import init_logger
+
+ logger = init_logger(__name__)
+ # Do not set the HuggingFace token here; it should be set via the env var `HF_TOKEN`.
+ hfs = HfFileSystem()
+
+ LOCK_DIR = "/tmp/lock"
+
+ ##### Local file utils #####
+
+
+ def run_cmd(cmd: str, *args, **kwargs) -> subprocess.CompletedProcess:
+     return subprocess.run(cmd.split(), *args, **kwargs)
+
+
+ def delete_file(path: str) -> None:
+     if os.path.isfile(path):
+         os.remove(path)
+     else:
+         logger.error(f"Trying to delete non-existing file: {path}")
+
+
+ def list_files(dir: str, pattern: str = "*") -> List[str]:
+     files = glob.glob(os.path.join(dir, pattern))
+     return files
+
+
+ def get_lock(model_name_or_path: str):
+     lock_dir = LOCK_DIR
+     model_name_or_path = str(model_name_or_path)
+     # Create the lock directory itself, not just its parent.
+     os.makedirs(lock_dir, exist_ok=True)
+     model_name = model_name_or_path.replace("/", "-")
+     hash_name = hashlib.sha256(model_name.encode()).hexdigest()
+     # add hash to avoid conflict with old users' lock files
+     lock_file_name = hash_name + model_name + ".lock"
+     # mode 0o666 is required for the filelock to be shared across users
+     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
+                              mode=0o666)
+     return lock
+
+
+ def get_free_disk_size(path: str = "/") -> int:
+     free_bytes = shutil.disk_usage(path).free
+     return free_bytes
+
+
+ ##### HuggingFace file utils #####
+
+
+ def is_hf_repo(repo_id: str) -> bool:
+     return hfs.exists(repo_id)
+
+
+ def list_hf_repo(repo_id: str, pattern: str = "**") -> List[str]:
+     repo_files = hfs.glob(os.path.join(repo_id, pattern))
+     return repo_files
+
+
+ def get_hf_model_weights_size(repo_id: str, weights_format: str) -> int:
+     weights_paths = list_hf_repo(repo_id, weights_format)
+     weights_size = 0
+     for weights_path in weights_paths:
+         weights_size += int(hfs.info(weights_path)["size"])
+     return weights_size
+
+
+ class DisabledTqdm(tqdm):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs, disable=True)
+
+
+ def download_model_weights_from_hf(model_path: str, cache_dir: Optional[str],
+                                    weights_format: str) -> List[str]:
+     with get_lock(model_path):
+         local_dir = snapshot_download(
+             model_path,
+             cache_dir=cache_dir,  # can be specified by HF_HOME or HF_HUB_CACHE
+             allow_patterns=weights_format,
+             tqdm_class=DisabledTqdm,
+             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+         )
+     local_files = list_files(local_dir, weights_format)
+     return local_files
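Taken together, these helpers gate a snapshot_download behind a per-model file lock so concurrent workers do not race on the same cache entry. A usage sketch under assumed conditions (a hypothetical repo id, safetensors weights at the repo root, and either network access or a warm HF cache):

from tpu_inference.models.jax.utils.file_utils import (
    download_model_weights_from_hf, get_free_disk_size,
    get_hf_model_weights_size)

repo_id = "Qwen/Qwen3-0.6B"   # hypothetical example repo
pattern = "*.safetensors"     # weights_format glob passed to allow_patterns

# Optionally check that the weights fit on disk before downloading.
if get_hf_model_weights_size(repo_id, pattern) < get_free_disk_size("/"):
    local_files = download_model_weights_from_hf(repo_id, cache_dir=None,
                                                 weights_format=pattern)
    print(local_files)  # local paths of the downloaded weight shards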
tpu_inference/models/jax/utils/multi_modal_utils.py
@@ -0,0 +1,163 @@
+ from typing import Union
+
+ import jax
+ import jax.numpy as jnp
+ from typing_extensions import TypeAlias
+ from vllm.logger import init_logger
+
+ logger = init_logger(__name__)
+
+ NestedTensors: TypeAlias = Union[list["NestedTensors"], list["jax.Array"],
+                                  "jax.Array", tuple["jax.Array", ...]]
+ """
+ Uses a list instead of a tensor if the dimensions of each element do not match.
+ """
+
+ MultiModalEmbeddings = Union[list[jax.Array], jax.Array, tuple[jax.Array, ...]]
+ """
+ The output embeddings must be one of the following formats:
+
+ - A list or tuple of 2D tensors, where each tensor corresponds to
+   one input multimodal data item (e.g., an image).
+ - A single 3D tensor, with the batch dimension grouping the 2D tensors.
+ """
+
+
+ def sanity_check_mm_encoder_outputs(
+     mm_embeddings: MultiModalEmbeddings,
+     expected_num_items: int,
+ ) -> None:
+     """
+     Perform sanity checks for the result of
+     [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+     """
+     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
+         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
+         f"or a single 3D tensor, but got {type(mm_embeddings)} "
+         "instead. This is most likely due to an incorrect implementation "
+         "of the model's `get_multimodal_embeddings` method.")
+
+     assert len(mm_embeddings) == expected_num_items, (
+         "Expected the number of multimodal embeddings to match the number of "
+         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
+         "instead. This is most likely due to an incorrect implementation "
+         "of the model's `get_multimodal_embeddings` method.")
+
+     assert all(e.ndim == 2 for e in mm_embeddings), (
+         "Expected multimodal embeddings to be a sequence of 2D tensors, "
+         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
+         "instead. This is most likely due to an incorrect implementation "
+         "of the model's `get_multimodal_embeddings` method.")
+
+
+ def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
+     """
+     Recursively flattens and concatenates NestedTensors on all but the last
+     dimension.
+     """
+
+     if isinstance(embeddings, jax.Array):
+         return embeddings.reshape(-1, embeddings.shape[-1])
+
+     return jnp.concatenate([flatten_embeddings(t) for t in embeddings], axis=0)
+
+
+ def _embedding_count_expression(embeddings: NestedTensors) -> str:
+     """
+     Constructs a debugging representation of the number of embeddings in the
+     NestedTensors.
+     """
+
+     if isinstance(embeddings, jax.Array):
+         return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+     return " + ".join(
+         _embedding_count_expression(inner) for inner in embeddings)
+
+
+ def _merge_multimodal_embeddings(
+     inputs_embeds: jax.Array,
+     is_multimodal: jax.Array,
+     multimodal_embeddings: jax.Array,
+ ) -> jax.Array:
+     """
+     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+     positions in ``inputs_embeds`` that correspond to placeholder tokens in
+     ``input_ids``.
+
+     Note:
+         This returns a new array with the updated values.
+     """
+     # The check for a matching number of tokens is omitted because it is not
+     # JIT-compatible. If the shapes mismatch, JAX will raise an error during
+     # execution anyway; the user-friendly error message is sacrificed for
+     # JIT compatibility.
+
+     # JIT-compatible implementation using jnp.where to avoid
+     # NonConcreteBooleanIndexError.
+     # Create a dummy row to handle indices for non-multimodal tokens.
+     # The content of the dummy row does not matter as it will be masked out.
+     dummy_row = jnp.zeros_like(multimodal_embeddings[0:1])
+
+     # Prepend the dummy row to the flattened embeddings.
+     flattened_padded = jnp.concatenate([dummy_row, multimodal_embeddings],
+                                        axis=0)
+
+     # Create gather indices. For each token in the input sequence, this gives
+     # the index into `flattened_padded`.
+     # For non-multimodal tokens, the index will be 0 (pointing to the dummy
+     # row). For the k-th multimodal token, the index will be k.
+     gather_indices = jnp.cumsum(is_multimodal)
+
+     # Gather the embeddings to be placed.
+     update_values = flattened_padded[gather_indices]
+
+     # Use jnp.where to select between original and new embeddings.
+     condition = jnp.expand_dims(is_multimodal, axis=-1)
+     return jnp.where(condition, update_values, inputs_embeds)
+
+
+ def merge_multimodal_embeddings(
+     input_ids: jax.Array,
+     inputs_embeds: jax.Array,
+     multimodal_embeddings: jax.Array,
+     placeholder_token_id: Union[int, list[int]],
+ ) -> jax.Array:
+     """
+     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+     positions in ``inputs_embeds`` that correspond to placeholder tokens in
+     ``input_ids``.
+
+     ``placeholder_token_id`` can be a list of token ids (e.g., the token ids
+     of img_start, img_break, and img_end tokens) when needed: this means
+     the order of these tokens in the ``input_ids`` MUST MATCH the order of
+     their embeddings in ``multimodal_embeddings`` since we need to
+     slice-merge instead of individually scattering.
+
+     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
+     - T is a text token
+     - S is an image start token
+     - I is an image embedding token
+     - B is an image break token
+     - E is an image end token,
+
+     then the image embeddings (those corresponding to I's) from the vision
+     encoder must be padded with embeddings of S, B, and E in the same order
+     as in input_ids for a correct embedding merge.
+
+     This returns a new array with the updated values.
+     """
+     if isinstance(placeholder_token_id, list):
+         placeholder_token_id = jnp.array(placeholder_token_id)
+         return _merge_multimodal_embeddings(
+             inputs_embeds,
+             jnp.isin(input_ids, placeholder_token_id),
+             multimodal_embeddings,
+         )
+
+     return _merge_multimodal_embeddings(
+         inputs_embeds,
+         (input_ids == placeholder_token_id),
+         multimodal_embeddings,
+     )
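The dummy-row trick above is worth seeing on concrete numbers: prepending a zero row lets cumsum send every non-multimodal position to index 0, and the final where masks those rows out again. A self-contained toy run of the same gather-and-mask logic, with toy ids and embeddings of my own choosing:

import jax.numpy as jnp

input_ids = jnp.array([1, 2, 9, 9, 9, 3])              # 9 is the placeholder id
inputs_embeds = jnp.zeros((6, 2))                       # text embeddings (toy)
mm_embeds = jnp.array([[1., 1.], [2., 2.], [3., 3.]])   # one row per placeholder

is_mm = input_ids == 9                                  # [F, F, T, T, T, F]
gather_idx = jnp.cumsum(is_mm)                          # [0, 0, 1, 2, 3, 3]
padded = jnp.concatenate([jnp.zeros((1, 2)), mm_embeds])  # dummy row at index 0
merged = jnp.where(is_mm[:, None], padded[gather_idx], inputs_embeds)

# Rows 2-4 now hold [1, 1], [2, 2], [3, 3]; every other row keeps inputs_embeds.
print(merged)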
tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml
@@ -0,0 +1,5 @@
+ qwix:
+   rules:
+     # NOTE: each entry corresponds to a qwix.QuantizationRule
+     - module_path: '.*'
+       weight_qtype: 'float8_e4m3fn'
tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml
@@ -0,0 +1,6 @@
+ qwix:
+   rules:
+     # NOTE: each entry corresponds to a qwix.QuantizationRule
+     - module_path: '.*'
+       weight_qtype: 'float8_e4m3fn'
+       act_qtype: 'float8_e4m3fn'
tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml
@@ -0,0 +1,5 @@
+ qwix:
+   rules:
+     # NOTE: each entry corresponds to a qwix.QuantizationRule
+     - module_path: '.*'
+       weight_qtype: 'int8'
tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml
@@ -0,0 +1,6 @@
+ qwix:
+   rules:
+     # NOTE: each entry corresponds to a qwix.QuantizationRule
+     - module_path: '.*'
+       weight_qtype: 'int8'
+       act_qtype: 'int8'
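The four configs differ only in the quantized dtype and in whether activations are quantized alongside weights (the *_w_only variants omit act_qtype). A minimal sketch of reading one of them with PyYAML; how the fields are interpreted is up to qwix, so the loop below only inspects them:

import yaml  # PyYAML

cfg_text = """
qwix:
  rules:
    # NOTE: each entry corresponds to a qwix.QuantizationRule
    - module_path: '.*'
      weight_qtype: 'int8'
      act_qtype: 'int8'
"""

cfg = yaml.safe_load(cfg_text)
for rule in cfg["qwix"]["rules"]:
    # '.*' matches every module path; act_qtype is absent in the w_only configs.
    print(rule["module_path"], rule["weight_qtype"], rule.get("act_qtype"))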