tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
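The `{common → jax}` entries above are module moves: `attention_interface.py`, `sharding.py`, and `binary_search.py` now live under `tpu_inference/layers/jax/`, while `attention_metadata` remains under `layers/common`. A minimal, illustrative migration sketch for downstream imports (the paths are taken from the hunks below; this snippet is not part of the diff itself):

# Old (0.0.1rc1):
#   from tpu_inference.layers.common.attention_interface import attention
# New (0.11.1.dev202511130813):
from tpu_inference.layers.jax.attention_interface import attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata  # unchanged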
tpu_inference/models/jax/phi3.py (new)
@@ -0,0 +1,376 @@
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import Mesh
+from transformers import Phi3Config, modeling_flax_utils
+from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention_interface import attention
+from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                         load_hf_weights)
+
+logger = init_logger(__name__)
+
+init_fn = nnx.initializers.uniform()
+
+
+class Phi3MLP(nnx.Module):
+
+    def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs):
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        act = config.hidden_act
+
+        self.gate_up_proj = nnx.Linear(
+            hidden_size,
+            2 * intermediate_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+            rngs=rng,
+        )
+        self.down_proj = nnx.Linear(
+            intermediate_size,
+            hidden_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+            rngs=rng,
+        )
+        self.act_fn = modeling_flax_utils.ACT2FN[act]
+
+    def __call__(self, x: jax.Array) -> jax.Array:
+        gate_up = self.gate_up_proj(x)
+        gate, up = jnp.split(gate_up, 2, axis=-1)
+        fuse = up * self.act_fn(gate)
+        result = self.down_proj(fuse)
+        return result
+
+
+class Phi3Attention(nnx.Module):
+
+    def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.rope_theta = config.rope_theta
+        self.rope_scaling = config.rope_scaling
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+        self.max_position_embeddings = config.max_position_embeddings
+
+        self.head_dim_original = getattr(config, "head_dim",
+                                         self.hidden_size // self.num_heads)
+        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+        sharding_size = mesh.shape["model"]
+        self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                    sharding_size)
+        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                       sharding_size)
+
+        self.mesh = mesh
+
+        self.qkv_proj = nnx.Einsum(
+            "TD,DNH->TNH",
+            (self.hidden_size, self.num_heads + self.num_kv_heads * 2,
+             self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.o_proj = nnx.Einsum(
+            "TNH,NHD->TD",
+            (self.num_heads, self.head_dim, self.hidden_size),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+            rngs=rng,
+        )
+
+        self._q_scale = 1.0
+        self._k_scale = 1.0
+        self._v_scale = 1.0
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                kv_cache_dtype)
+
+    def __call__(
+        self,
+        kv_cache: Optional[jax.Array],
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        md = attention_metadata
+        # qkv: (T, N + K * 2, H)
+        qkv = self.qkv_proj(x)
+        q, k, v = jnp.split(
+            qkv, [self.num_heads, self.num_heads + self.num_kv_heads], axis=1)
+        if self.rope_scaling:
+            q = apply_longrope(q, md.input_positions, self.head_dim_original,
+                               self.rope_scaling,
+                               self.original_max_position_embeddings,
+                               self.max_position_embeddings, self.rope_theta)
+            k = apply_longrope(k, md.input_positions, self.head_dim_original,
+                               self.rope_scaling,
+                               self.original_max_position_embeddings,
+                               self.max_position_embeddings, self.rope_theta)
+        else:
+            q = apply_rope(q, md.input_positions, self.head_dim_original,
+                           self.rope_theta, self.rope_scaling)
+            k = apply_rope(k, md.input_positions, self.head_dim_original,
+                           self.rope_theta, self.rope_scaling)
+        # o: (T, N, H)
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = self._q_scale
+            k_scale = self._k_scale
+            v_scale = self._v_scale
+            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                     k_scale, v_scale)
+        new_kv_cache, outputs = attention(
+            kv_cache,
+            q,
+            k,
+            v,
+            attention_metadata,
+            self.mesh,
+            self.head_dim_original,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+        # (T, D)
+        o = self.o_proj(outputs)
+        return new_kv_cache, o
+
+
+class Phi3DecoderLayer(nnx.Module):
+
+    def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        rms_norm_eps = config.rms_norm_eps
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.self_attn = Phi3Attention(config=config,
+                                       dtype=dtype,
+                                       rng=rng,
+                                       mesh=mesh,
+                                       kv_cache_dtype=kv_cache_dtype)
+        self.post_attention_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.mlp = Phi3MLP(
+            config=config,
+            dtype=dtype,
+            rng=rng,
+        )
+
+    def __call__(
+        self,
+        kv_cache: jax.Array,
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        hidden_states = self.input_layernorm(x)
+        kv_cache, attn_output = self.self_attn(
+            kv_cache,
+            hidden_states,
+            attention_metadata,
+        )
+        attn_output += x
+
+        residual = attn_output
+        attn_output = self.post_attention_layernorm(attn_output)
+        outputs = self.mlp(attn_output)
+        outputs = residual + outputs
+        return kv_cache, outputs
+
+
+class Phi3Model(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                 mesh: Mesh) -> None:
+        model_config = vllm_config.model_config
+        hf_config = model_config.hf_config
+        vocab_size = model_config.get_vocab_size()
+        dtype = model_config.dtype
+        rms_norm_eps = hf_config.rms_norm_eps
+        hidden_size = hf_config.hidden_size
+
+        self.embed = nnx.Embed(
+            num_embeddings=vocab_size,
+            features=hidden_size,
+            param_dtype=dtype,
+            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+            rngs=rng,
+        )
+        self.layers = [
+            Phi3DecoderLayer(
+                config=hf_config,
+                dtype=dtype,
+                rng=rng,
+                mesh=mesh,
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+            for _ in range(hf_config.num_hidden_layers)
+        ]
+        self.norm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        if model_config.hf_config.tie_word_embeddings:
+            self.lm_head = self.embed.embedding
+        else:
+            self.lm_head = nnx.Param(
+                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                sharding=(None, "model"),
+            )
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[List[jax.Array], jax.Array]:
+        x = self.embed(input_ids)
+        for i, layer in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            kv_cache, x = layer(
+                kv_cache,
+                x,
+                attention_metadata,
+            )
+            kv_caches[i] = kv_cache
+        x = self.norm(x)
+        return kv_caches, x
+
+
+class Phi3ForCausalLM(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                 mesh: Mesh) -> None:
+        self.vllm_config = vllm_config
+        self.rng = nnx.Rngs(rng_key)
+        self.mesh = mesh
+
+        self.model = Phi3Model(
+            vllm_config=vllm_config,
+            rng=self.rng,
+            mesh=mesh,
+        )
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+        kv_caches, x = self.model(
+            kv_caches,
+            input_ids,
+            attention_metadata,
+        )
+        return kv_caches, x, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+        else:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value)
+        return logits
+
+    def get_metadata_map(self) -> MetadataMap:
+        sharding_size = self.mesh.shape["model"]
+
+        model_config = self.vllm_config.model_config
+        hf_config = model_config.hf_config
+
+        num_heads = hf_config.num_attention_heads
+        num_kv_heads = hf_config.num_key_value_heads
+        qkv_heads = num_heads + num_kv_heads * 2
+        hidden_size = model_config.get_hidden_size()
+
+        # Pad head_dim for kernel performance.
+        head_dim_original = model_config.get_head_size()
+
+        # Key: path to a HF layer weight
+        # Value: path to a nnx layer weight
+        name_map = {
+            "model.embed_tokens": "model.embed.embedding",
+            "model.layers.*.input_layernorm":
+            "model.layers.*.input_layernorm.scale",
+            "model.layers.*.mlp.down_proj":
+            "model.layers.*.mlp.down_proj.kernel",
+            "model.layers.*.mlp.gate_up_proj":
+            "model.layers.*.mlp.gate_up_proj.kernel",
+            "model.layers.*.post_attention_layernorm":
+            "model.layers.*.post_attention_layernorm.scale",
+            "model.layers.*.self_attn.qkv_proj":
+            "model.layers.*.self_attn.qkv_proj.kernel",
+            "model.layers.*.self_attn.o_proj":
+            "model.layers.*.self_attn.o_proj.kernel",
+            "model.norm": "model.norm.scale",
+        }
+        if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            name_map.update({
+                "lm_head": "model.lm_head",
+            })
+
+        reshape_keys: dict[str, tuple[int, ...]] = {
+            "qkv_proj": (qkv_heads, head_dim_original, hidden_size),
+            "o_proj": (hidden_size, num_heads, head_dim_original),
+        }
+        transpose_keys: dict[str, tuple[int, ...]] = {
+            "lm_head": (1, 0),
+            "gate_up_proj": (1, 0),
+            "down_proj": (1, 0),
+            "qkv_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+        }
+
+        # key: (padding_dim, padding_size)
+        pad_keys: dict[str, tuple[int, ...]] = {
+            "qkv_proj": (1, sharding_size // num_heads),
+            "o_proj": (0, sharding_size // num_heads),
+        }
+
+        return MetadataMap(name_map=name_map,
+                           reshape_map=reshape_keys,
+                           bias_reshape_map={},
+                           transpose_map=transpose_keys,
+                           pad_map=pad_keys,
+                           bias_pad_map={})
+
+    def load_weights(self, rng_key: jax.Array):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng_key)
+
+        metadata_map = self.get_metadata_map()
+        load_hf_weights(vllm_config=self.vllm_config,
+                        model=self,
+                        metadata_map=metadata_map,
+                        mesh=self.mesh)
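`phi3.py` is the largest addition in this release. A minimal usage sketch, assuming a populated `VllmConfig`, a `Mesh` with a "model" axis, per-layer KV caches, and an `AttentionMetadata` instance are already provided by the runner (none of this scaffolding is part of the diff):

import jax
from tpu_inference.models.jax.phi3 import Phi3ForCausalLM

rng_key = jax.random.key(0)
model = Phi3ForCausalLM(vllm_config, rng_key, mesh)
model.load_weights(rng_key)  # maps HF checkpoint names via get_metadata_map()

# Forward pass returns updated caches, hidden states, and (empty) aux outputs.
kv_caches, hidden_states, _ = model(kv_caches, input_ids, attention_metadata)
logits = model.compute_logits(hidden_states)  # (num_tokens, vocab_size)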
tpu_inference/models/jax/qwen2.py
@@ -8,8 +8,8 @@ from transformers import Qwen2Config, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
-from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -368,8 +368,7 @@ class Qwen2ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config.
-                                        self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen2_5_vl.py
@@ -14,9 +14,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.config import VllmConfig
 
 from tpu_inference import utils as utils
-from tpu_inference.layers.common.attention_interface import \
-    sharded_flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention_interface import \
+    sharded_flash_attention
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
 # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
@@ -486,11 +486,6 @@ class Qwen2_5_VisionTransformer(nnx.Module):
                                            dtype=dtype,
                                            rngs=rngs)
 
-        additional_config = getattr(vllm_config, "additional_config",
-                                    None) or {}
-        self.enable_dynamic_image_sizes = additional_config.get(
-            "enable_dynamic_image_sizes", False)
-
     def rotary_pos_emb_thw(self, t, h, w):
         hpos_ids, wpos_ids = jnp.indices((h, w))
         hpos_ids = hpos_ids.reshape(
@@ -584,7 +579,21 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         return max_seqlen, seqlens
 
-    def
+    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                           int]]) -> jax.Array:
+        # x: pixel_values: jax.Array
+        # """Shape:
+        # `(num_patches, num_channels * patch_size * patch_size)`
+        # """
+
+        # grid_thw: image_grid_thw: jax.Array
+        # """Shape: `(num_images, 3)`
+        # This should be in `(grid_t, grid_h, grid_w)` format.
+        # """
+        hidden_states = self.patch_embed(x)
+
+        # num of patches
+        seq_len = x.shape[0]
         # num of images/videoes
         num_grids = len(grid_thw)
 
@@ -629,42 +638,6 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                              mode='constant',
                              constant_values=0)
-        return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
-
-    def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
-                   cu_window_seqlens):
-        # padding
-        num_patches = int(rotary_pos_emb.shape[0])
-        bucket_num_patches = 1 << (num_patches - 1).bit_length()
-        num_tokens = window_index.shape[0]
-        bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
-        vit_merger_window_size = (self.window_size //
-                                  self.spatial_merge_size // self.patch_size)
-        max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
-
-        rotary_pos_emb = jnp.pad(rotary_pos_emb,
-                                 ((0, bucket_num_patches - num_patches),
-                                  (0, 0)))
-        window_index = jnp.concatenate([
-            window_index,
-            jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
-        ])
-        cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
-        pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
-        cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
-        cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
-
-        x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
-
-        return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
-
-    def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
-                              rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
-                              cu_window_seqlens: jax.Array) -> jax.Array:
-        hidden_states = self.patch_embed(x)
-
-        # num of patches
-        seq_len = x.shape[0]
 
         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -691,48 +664,6 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         hidden_states = hidden_states[reverse_indices, :]
         return hidden_states
 
-    @jax.jit
-    def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
-                          cu_seqlens, cu_window_seqlens):
-        return self.compute_hidden_states(x_padded, window_index,
-                                          rotary_pos_emb, cu_seqlens,
-                                          cu_window_seqlens)
-
-    @partial(
-        jax.jit,
-        static_argnames=("grid_thw", ),
-    )
-    def encode_jit(self, x, grid_thw):
-        window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
-            grid_thw)
-        return self.compute_hidden_states(x, window_index, rotary_pos_emb,
-                                          cu_seqlens, cu_window_seqlens)
-
-    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
-                                                           int]]) -> jax.Array:
-        # x: pixel_values: jax.Array
-        # """Shape:
-        # `(num_patches, num_channels * patch_size * patch_size)`
-        # """
-
-        # grid_thw: image_grid_thw: jax.Array
-        # """Shape: `(num_images, 3)`
-        # This should be in `(grid_t, grid_h, grid_w)` format.
-        # """
-        if self.enable_dynamic_image_sizes:
-            window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
-                grid_thw)
-            x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
-                x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
-
-            hidden_states = self.encode_padded_jit(x_padded, window_index,
-                                                   rotary_pos_emb, cu_seqlens,
-                                                   cu_window_seqlens)
-            return hidden_states[:num_tokens]
-
-        else:
-            return self.encode_jit(x, grid_thw)
-
 
 class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
@@ -957,6 +888,10 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         # "video"] = self._parse_and_validate_video_input(**kwargs)
         return mm_input_by_modality
 
+    @partial(
+        jax.jit,
+        static_argnames=("image_grid_thw", ),
+    )
     def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
         return self.visual(image_pixel_values, (image_grid_thw, ))
 
@@ -1126,8 +1061,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         "lm_head": "language_model.model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config.
-                                        self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
@@ -1137,82 +1071,33 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         self,
         run_compilation_fn: Callable,
     ) -> None:
+        image_shapes = []
+        if (warmup_config := self.vllm_config.additional_config.get(
+                "vision_warmup_config")):
+            image_shapes = warmup_config.get("image_shapes")
+
         vc = self.vllm_config.model_config.hf_config.vision_config
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for num_patches in num_patches_paddings:
-            dummy_x_padded = jnp.ones(
-                (num_patches, patch_input_dim),
-                dtype=self.vllm_config.model_config.dtype)
-
-            num_tokens = num_patches // spatial_merge_unit
-            dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
-
-            dummy_rotary_pos_emb = jnp.ones(
-                (num_patches, rotary_dim),
-                dtype=self.vllm_config.model_config.dtype)
-
-            dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
-                                         dtype=jnp.int32)
-
-            max_windows = (num_tokens // vit_merger_window_size) + 2
-            patches_per_window = (vit_merger_window_size**
-                                  2) * spatial_merge_unit
-            dummy_cu_window_seqlens = jnp.arange(
-                max_windows + 1, dtype=jnp.int32) * patches_per_window
-            dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
-                                                  num_patches)
-
-            run_compilation_fn("vision_encoder_padded",
-                               self.visual.encode_padded_jit,
-                               dummy_x_padded,
-                               dummy_window_index,
-                               dummy_rotary_pos_emb,
-                               dummy_cu_seqlens,
-                               dummy_cu_window_seqlens,
-                               num_patches=num_patches)
-        else:
-            image_shapes = []
-            if (warmup_config := self.vllm_config.additional_config.get(
-                    "vision_warmup_config")):
-                image_shapes = warmup_config.get("image_shapes")
-
-            factor = vc.patch_size * vc.spatial_merge_size
-            for input_hw in image_shapes:
-                if not isinstance(input_hw, list) or len(input_hw) != 2:
-                    logger.warning(f"Skipping invalid shape {input_hw}.")
-                    continue
-                h_input, w_input = input_hw
-                h_processed = round(h_input / factor) * factor
-                w_processed = round(w_input / factor) * factor
-                t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
-                grid_thw = (t, h, w)
-                num_patches = t * h * w
-
-                dummy_pixel_values = jnp.ones(
-                    (num_patches, patch_input_dim),
-                    self.vllm_config.model_config.dtype,
-                )
-                dummy_grid_thw = (grid_thw, )
+        factor = vc.patch_size * vc.spatial_merge_size
+        for input_hw in image_shapes:
+            if not isinstance(input_hw, list) or len(input_hw) != 2:
+                logger.warning(f"Skipping invalid shape {input_hw}.")
+                continue
+            h_input, w_input = input_hw
+            h_processed = round(h_input / factor) * factor
+            w_processed = round(w_input / factor) * factor
+            t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
+            grid_thw = (t, h, w)
+            num_patches = t * h * w
+            patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
+
+            dummy_pixel_values = jnp.ones(
+                (num_patches, patch_input_dim),
+                self.vllm_config.model_config.dtype,
+            )
+            dummy_grid_thw = grid_thw
 
-
-
-
-
-
+            run_compilation_fn("single_image_encoder",
+                               self.get_single_image_embedding,
+                               dummy_pixel_values,
+                               dummy_grid_thw,
+                               image_shape=input_hw)
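The rewritten warmup path above reads image shapes from `additional_config`. Based on the accessors in the hunk, the expected layout is a dict like the following (values are illustrative, not from the diff; each entry is a [height, width] pair in pixels, and entries that are not two-element lists are skipped with a warning):

additional_config = {
    "vision_warmup_config": {
        "image_shapes": [[336, 336], [672, 1008]],
    },
}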
tpu_inference/models/jax/qwen3.py
@@ -8,8 +8,8 @@ from transformers import Qwen3Config
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
-from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -295,8 +295,7 @@ class Qwen3ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config.
-                                        self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
|