tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen2_5_vl.py
@@ -0,0 +1,976 @@
import math
from functools import partial
from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
                    Union)

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax.sharding import Mesh
from transformers import modeling_flax_utils
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.model_executor.models.qwen2_5_vl import \
    Qwen2_5_VLForConditionalGeneration as vllm_model_cls

from tpu_inference import utils as utils
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.jax.attention_interface import \
    sharded_flash_attention
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
# from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from tpu_inference.models.jax.utils.multi_modal_utils import (
    MultiModalEmbeddings, merge_multimodal_embeddings)
from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                         load_hf_weights)

logger = init_logger(__name__)

init_fn = nnx.initializers.uniform()

DEFAULT_BLOCK_K_MAJOR = 128


class SegmentIds(NamedTuple):
    """SegmentIds for Q and KV sequences.

    SegmentIds are used to generate segment mask, which prevents attention between
    different segments in the input sequence. Each array is a list of ids
    (integers).
    Only the token with the same id can attend to each other.

    Attributes:
      q: segment ids along the Q sequence.
      kv: segment ids along the KV sequence.
    """

    q: jax.Array  # [batch_size, q_seq_len]
    kv: jax.Array  # [batch_size, kv_seq_len]


class Qwen2_5_VLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: jax.Array
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """

    image_grid_thw: tuple[tuple[int, int, int], ...]
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# NOTE: We are not supporting embedding inputs for now
# The code here makes the struture consistent and
# makes iteasier for future implementation
class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    image_embeds: jax.Array
    """Supported types:
    - list[`jax.Array`]: A list of tensors holding all images' features.
        Each tensor holds an image's features.
    - `jax.Array`: A tensor holding all images' features (concatenation of
        all images' feature tensors).

    Tensor shape: `(num_image_features, hidden_size)`
    - `num_image_features` varies based on
        the number and resolution of the images.
    - `hidden_size` must match the hidden size of language model backbone.
    """

    image_grid_thw: jax.Array
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
                              Qwen2_5_VLImageEmbeddingInputs]


class Qwen2_5_VisionMLP(nnx.Module):

    def __init__(self, config: Qwen2_5_VLVisionConfig, dtype: jnp.dtype,
                 rngs: nnx.Rngs):
        in_features = config.hidden_size
        hidden_features = config.intermediate_size
        act_fn = modeling_flax_utils.ACT2FN[config.hidden_act]
        self.gate_proj = nnx.Linear(
            in_features,
            hidden_features,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs,
        )
        self.up_proj = nnx.Linear(
            in_features,
            hidden_features,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs,
        )
        self.down_proj = nnx.Linear(
            hidden_features,
            in_features,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            bias_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rngs,
        )
        self.act_fn = act_fn

    def __call__(self, x: jax.Array) -> jax.Array:
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        fuse = gate * up
        result = self.down_proj(fuse)
        return result


def apply_rotary_pos_emb_vision(x: jax.Array,
                                rotary_pos_emb: jax.Array) -> jax.Array:
    # x: [B, T, N, H]
    # rotary_pos_emb: [T, H//2]
    _, _, _, H = x.shape
    half_dim = H // 2

    # [B, T, N, H//2]
    x_real = x[..., :half_dim]
    x_imag = x[..., half_dim:]

    # [T, H//2]
    cos_emb = jnp.cos(rotary_pos_emb)
    sin_emb = jnp.sin(rotary_pos_emb)

    # [1, T, 1, H//2]
    cos_emb = cos_emb[None, :, None, :]
    sin_emb = sin_emb[None, :, None, :]

    # [B, T, N, H//2]
    x_rotated_real = x_real * cos_emb - x_imag * sin_emb
    x_rotated_imag = x_real * sin_emb + x_imag * cos_emb

    # [B, T, N, H]
    x_rotated = jnp.concatenate([x_rotated_real, x_rotated_imag], axis=-1)

    return x_rotated


def generate_window_segment_ids(cu_seqlens: jax.Array, seq_len: int,
                                padded_seq_len: int) -> SegmentIds:
    """Generates segment IDs for windowed attention

    Args:
      cu_seqlens: A 1D array of cumulative sequence lengths for each window.
        e.g., [0, len_win0, len_win0+len_win1, ...]

    Returns:
      A SegmentIds object for flash_attention.
    """
    indices = jnp.arange(seq_len, dtype=jnp.int32)
    segment_ids = jnp.searchsorted(cu_seqlens[1:], indices, side='right') + 1
    padding_segment_ids = jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)
    segment_ids = jnp.concatenate([segment_ids, padding_segment_ids])
    segment_ids = segment_ids.reshape(1, -1)

    return SegmentIds(q=segment_ids, kv=segment_ids)


class Qwen2_5_VisionAttention(nnx.Module):

    def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
                 rngs: nnx.Rngs, mesh: Mesh):
        vision_config = config.vision_config
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_kv_heads = self.num_heads
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.head_dim_original = self.hidden_size // self.num_heads

        sharding_size = mesh.shape["model"]
        self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                    sharding_size)
        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
                                                       sharding_size)
        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)

        # TODO: Wenlong: Do not consider padding for now
        self.head_dim = self.head_dim_original

        self.mesh = mesh

        self.qkv_proj = nnx.Linear(
            self.hidden_size,
            3 * self.hidden_size,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs,
        )

        self.proj = nnx.Linear(
            self.hidden_size,
            self.hidden_size,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            bias_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rngs,
        )
        self.flash_attention = sharded_flash_attention(
            mesh=mesh,
            causal=False,
            sm_scale=1.0 / math.sqrt(self.head_dim),
            vmem_limit_bytes=128 * 1024 * 1024,
        )

    def __call__(
        self,
        x: jax.Array,
        rotary_pos_emb: jax.Array,
        cu_window_seqlens: Optional[jax.Array] = None,
        use_fullattn: bool = True,
    ) -> jax.Array:
        T, B, D = x.shape
        assert B == 1, "Vision attention currently only supports batch size 1"
        # [T, B, D] -> [T, B, 3 * D]
        qkv = self.qkv_proj(x)

        # Split into Q, K, V.
        # NOTE: simplified from vLLM's split_qkv,
        # may need to revisit for tp>1
        # [T, B, 3 * D] -> 3 *[T, B, D]
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # [T, B, N, H]
        q = q.reshape(T, B, self.num_heads, self.head_dim)
        k = k.reshape(T, B, self.num_heads, self.head_dim)
        v = v.reshape(T, B, self.num_heads, self.head_dim)

        # [T, B, N, H] -> [B, T, N, H]
        q = jnp.transpose(q, (1, 0, 2, 3))
        k = jnp.transpose(k, (1, 0, 2, 3))
        v = jnp.transpose(v, (1, 0, 2, 3))

        # rotary_pos_emb shape: (T, H)
        q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
        k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        # NOTE: an extra transpose because we need to
        # align the correctness with vLLM's design.
        # Might be able to remove one once implemented.
        # [B, T, N, H] -> [B, N, T, H]
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))

        # Pad the sequence length to be a multiple of 128 for flash_attention
        block_k_major = DEFAULT_BLOCK_K_MAJOR
        T_attn = q.shape[2]
        padded_T = (T_attn + block_k_major -
                    1) // block_k_major * block_k_major
        pad_width = ((0, 0), (0, 0), (0, padded_T - T_attn), (0, 0))

        q = jnp.pad(q, pad_width, 'constant')
        k = jnp.pad(k, pad_width, 'constant')
        v = jnp.pad(v, pad_width, 'constant')

        segment_ids = generate_window_segment_ids(cu_window_seqlens, T_attn,
                                                  padded_T)

        # TODO (jacobplatin): add support for quantized KV cache?
        output = self.flash_attention(q, k, v, segment_ids)

        # Unpad the output
        output = output[:, :, :T_attn, :]

        # [B, N, T, H] -> [T, B, N, H]
        output = jnp.transpose(output, (2, 0, 1, 3))

        output = output.reshape(T, B, D)

        output = self.proj(output)

        return output


class Qwen2_5_VisionBlock(nnx.Module):

    def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
                 rngs: nnx.Rngs, mesh: Mesh):
        vision_config = config.vision_config
        dim = vision_config.hidden_size
        norm_layer = partial(nnx.RMSNorm,
                             epsilon=config.rms_norm_eps,
                             scale_init=nnx.with_partitioning(
                                 init_fn, (None, )))

        self.norm1 = norm_layer(dim, dtype=dtype, rngs=rngs)
        self.norm2 = norm_layer(dim, dtype=dtype, rngs=rngs)
        self.attn = Qwen2_5_VisionAttention(config=config,
                                            dtype=dtype,
                                            rngs=rngs,
                                            mesh=mesh)
        self.mlp = Qwen2_5_VisionMLP(config=vision_config,
                                     dtype=dtype,
                                     rngs=rngs)

    def __call__(self,
                 x: jax.Array,
                 rotary_pos_emb: jax.Array,
                 cu_window_seqlens: Optional[jax.Array] = None,
                 use_fullattn: bool = True) -> jax.Array:

        x = x + self.attn(self.norm1(x), rotary_pos_emb, cu_window_seqlens,
                          use_fullattn)
        x = x + self.mlp(self.norm2(x))

        return x


class Qwen2_5_VisionPatchEmbed(nnx.Module):

    def __init__(
        self,
        rngs: nnx.Rngs,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
        dtype: jnp.dtype = jnp.bfloat16,
    ) -> None:
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nnx.Conv(in_features=in_channels,
                             out_features=hidden_size,
                             kernel_size=kernel_size,
                             strides=kernel_size,
                             use_bias=False,
                             param_dtype=dtype,
                             kernel_init=nnx.with_partitioning(
                                 init_fn, (None, None, None, None, "model")),
                             rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # x is (L, C * T * H * W)
        L, dim = x.shape
        C = dim // (self.temporal_patch_size * self.patch_size *
                    self.patch_size)
        # Reshape to (L, T, H, W, C) for Conv3D with channels_last
        x = x.reshape(L, C, self.temporal_patch_size, self.patch_size,
                      self.patch_size)
        # L,T,H,W,C
        x = jnp.transpose(x, (0, 2, 3, 4, 1))
        x = self.proj(x)
        # After conv, shape is (L, T_out, H_out, W_out, C_out)
        # With stride=kernel_size, T_out=H_out=W_out=1.
        # So shape is (L, 1, 1, 1, hidden_size)
        x = x.reshape(L, self.hidden_size)
        return x


class Qwen2_5_VisionPatchMerger(nnx.Module):

    def __init__(self, d_model: int, context_dim: int, norm_layer: Callable,
                 spatial_merge_size: int, dtype: jnp.dtype, rngs: nnx.Rngs):
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = norm_layer(context_dim,
                               dtype=dtype,
                               rngs=rngs,
                               scale_init=nnx.with_partitioning(
                                   init_fn, (None, )))
        self.mlp_fc1 = nnx.Linear(
            self.hidden_size,
            self.hidden_size,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs)
        self.mlp_act = modeling_flax_utils.ACT2FN["gelu"]
        self.mlp_fc2 = nnx.Linear(
            self.hidden_size,
            d_model,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            bias_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.ln_q(x)
        x = x.reshape(-1, self.hidden_size)
        x = self.mlp_fc1(x)
        x = self.mlp_act(x)
        x = self.mlp_fc2(x)
        return x


class Qwen2_5_VisionRotaryEmbedding(nnx.Module):

    def __init__(self, dim: int, theta: float = 10000.0):
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> jax.Array:
        inv_freq = 1.0 / (self.theta**(
            jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        seq = jnp.arange(seqlen, dtype=jnp.float32)
        freqs = jnp.outer(seq, inv_freq)
        return freqs.astype(jnp.bfloat16)


class Qwen2_5_VisionTransformer(nnx.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 rngs: nnx.Rngs,
                 mesh: Mesh,
                 norm_eps: float = 1e-6):
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        vision_config = hf_config.vision_config
        dtype = model_config.dtype

        self.config = vision_config
        self.dtype = dtype

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads

        # args for get_window_index_thw
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.spatial_merge_unit = self.spatial_merge_size**2

        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
            dtype=dtype,
            rngs=rngs)

        head_dim = vision_config.hidden_size // vision_config.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = [
            Qwen2_5_VisionBlock(
                config=hf_config,
                dtype=dtype,
                rngs=rngs,
                mesh=mesh,
            ) for _ in range(vision_config.depth)
        ]
        self.merger = Qwen2_5_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.hidden_size,
            norm_layer=partial(nnx.RMSNorm, epsilon=norm_eps),
            spatial_merge_size=vision_config.spatial_merge_size,
            dtype=dtype,
            rngs=rngs)

    def rotary_pos_emb_thw(self, t, h, w):
        hpos_ids, wpos_ids = jnp.indices((h, w))
        hpos_ids = hpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        ).transpose(0, 2, 1, 3).flatten()
        wpos_ids = wpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        ).transpose(0, 2, 1, 3).flatten()
        pos_ids = jnp.stack([hpos_ids, wpos_ids], axis=-1)
        pos_ids = jnp.tile(pos_ids, (t, 1))

        max_size = max(h, w)
        rotary_pos_emb_full = self.rotary_pos_emb(max_size)

        rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape(
            pos_ids.shape[0], -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            rotary_pos_emb.shape[0] // self.spatial_merge_unit,
            self.spatial_merge_unit, -1)

        return rotary_pos_emb

    def get_window_index_thw(self, grid_t, grid_h, grid_w):
        vit_merger_window_size = (self.window_size //
                                  self.spatial_merge_size // self.patch_size)

        llm_grid_h = grid_h // self.spatial_merge_size
        llm_grid_w = grid_w // self.spatial_merge_size

        index = jnp.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
            grid_t, llm_grid_h, llm_grid_w)

        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

        index_padded = jnp.pad(index, ((0, 0), (0, pad_h), (0, pad_w)),
                               constant_values=-100)
        index_padded = index_padded.reshape(grid_t, num_windows_h,
                                            vit_merger_window_size,
                                            num_windows_w,
                                            vit_merger_window_size)
        index_padded = jnp.transpose(index_padded, (0, 1, 3, 2, 4)).reshape(
            grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
            vit_merger_window_size)
        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
        index_padded = index_padded.reshape(-1)
        # The number of valid indices is static because grid_t, grid_h, grid_w
        # are static.
        num_valid_indices = grid_t * llm_grid_h * llm_grid_w
        valid_indices = jnp.nonzero(index_padded != -100,
                                    size=num_valid_indices)[0]
        index_new = index_padded[valid_indices]
        cu_seqlens_tmp = jnp.cumsum(seqlens) * self.spatial_merge_unit
        cu_seqlens_tmp = cu_seqlens_tmp.astype(jnp.int32)

        # NOTE (wenlong): Pytorch code uses this to reduce replication,
        # but I don't think there is a need here, plus it would cause problem in JIT
        # Please refer here if there is a problem down-stream
        # cu_seqlens_tmp = jnp.unique(cu_seqlens_tmp)

        return index_new, cu_seqlens_tmp

    def get_rope_by_thw(self, t, h, w):
        window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
            t, h, w)

        rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)

        rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
        rotary_pos_emb_thw = rotary_pos_emb_thw.reshape(
            -1, rotary_pos_emb_thw.shape[-1])
        cu_seqlens_thw = jnp.full(t, h * w, dtype=jnp.int32)

        return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
                cu_seqlens_thw)

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: jax.Array,
    ) -> tuple[Optional[int], Optional[list[int]]]:
        max_seqlen, seqlens = None
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
                                                           int]]) -> jax.Array:
        # x: pixel_values: jax.Array
        # """Shape:
        # `(num_patches, num_channels * patch_size * patch_size)`
        # """

        # grid_thw: image_grid_thw: jax.Array
        # """Shape: `(num_images, 3)`
        # This should be in `(grid_t, grid_h, grid_w)` format.
        # """
        hidden_states = self.patch_embed(x)

        # num of patches
        seq_len = x.shape[0]
        # num of images/videoes
        num_grids = len(grid_thw)

        rotary_pos_emb = []
        window_index: list = []
        cu_window_seqlens: list = [jnp.array([0], dtype=jnp.int32)]
        cu_seqlens: list = []

        window_index_id = 0
        cu_window_seqlens_last = 0
        for i in range(num_grids):
            t, h, w = grid_thw[i]

            llm_h = h // self.spatial_merge_size
            llm_w = w // self.spatial_merge_size

            (
                rotary_pos_emb_thw,
                window_index_thw,
                cu_seqlens_window_thw,
                cu_seqlens_thw,
            ) = self.get_rope_by_thw(t, h, w)

            window_index.append(window_index_thw + window_index_id)
            window_index_id += (t * llm_h * llm_w)

            cu_seqlens_window_thw = (cu_seqlens_window_thw +
                                     cu_window_seqlens_last)
            cu_window_seqlens_last = cu_seqlens_window_thw[-1]
            cu_window_seqlens.append(cu_seqlens_window_thw)

            rotary_pos_emb.append(rotary_pos_emb_thw)

            cu_seqlens.append(cu_seqlens_thw)

        rotary_pos_emb = jnp.concatenate(rotary_pos_emb, axis=0)
        window_index = jnp.concatenate(window_index, axis=0)
        cu_window_seqlens = jnp.concatenate(cu_window_seqlens, axis=0)

        cu_seqlens = jnp.concatenate(cu_seqlens, axis=0)
        cu_seqlens = jnp.cumsum(cu_seqlens, axis=0, dtype=jnp.int32)
        cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                             mode='constant',
                             constant_values=0)

        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        hidden_states = jnp.expand_dims(hidden_states, axis=1)

        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                hidden_states = blk(hidden_states,
                                    rotary_pos_emb=rotary_pos_emb,
                                    cu_window_seqlens=cu_seqlens,
                                    use_fullattn=True)
            else:
                hidden_states = blk(hidden_states,
                                    rotary_pos_emb=rotary_pos_emb,
                                    cu_window_seqlens=cu_window_seqlens,
                                    use_fullattn=False)

        # adapter
        hidden_states = self.merger(hidden_states)
        reverse_indices = jnp.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states


class Qwen2_5_VLForConditionalGeneration(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = Qwen2_5_VisionTransformer(
            vllm_config=vllm_config,
            rngs=self.rng,
            mesh=mesh,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
        )
        self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)

    @classmethod
    def get_mrope_input_positions(
        cls,
        input_tokens: list[int],
        hf_config,
        image_grid_thw,
        video_grid_thw,
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths=None,
        use_audio_in_video: bool = False,
    ):
        return vllm_model_cls.get_mrope_input_positions(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> jax.Array:
        if isinstance(mm_input, list):
            # Assuming it's a list of arrays (e.g., np.ndarray, torch.Tensor)
            # that can be concatenated.
            arrays_to_concat = [jnp.asarray(item) for item in mm_input]
            return jnp.concatenate(arrays_to_concat, axis=0)

        # Handle single array-like objects (np.ndarray, torch.Tensor, jax.Array)
        if hasattr(mm_input, 'ndim'):
            array_input = jnp.asarray(mm_input)
            if array_input.ndim == 2:
                return array_input
            if array_input.ndim == 3:
                # This reshapes the batched 3D tensor to a 2D tensor.
                return array_input.reshape(-1, array_input.shape[-1])

        raise ValueError(f"Incorrect type of {name}. "
                         f"Got type: {type(mm_input)}")

    def _parse_and_validate_image_input(
            self, image_grid_thw: tuple[tuple[int, int, int], ...],
            **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            # image_grid_thw = self._validate_and_reshape_mm_tensor(
            #     image_grid_thw, "image grid_thw")

            if not isinstance(pixel_values, jax.Array):
                raise ValueError("Incorrect type of image pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                              pixel_values=pixel_values,
                                              image_grid_thw=image_grid_thw)

        # Note: comment them out for now and save for future support
        # if image_embeds is not None:
        #     image_embeds = self._validate_and_reshape_mm_tensor(
        #         image_embeds, "image embeds")
        #     image_grid_thw = self._validate_and_reshape_mm_tensor(
        #         image_grid_thw, "image grid_thw")

        #     if not isinstance(image_embeds, jax.Array):
        #         raise ValueError("Incorrect type of image embeddings. "
        #                          f"Got type: {type(image_embeds)}")
        #     return Qwen2_5_VLImageEmbeddingInputs(
        #         type="image_embeds",
        #         image_embeds=image_embeds,
        #         image_grid_thw=image_grid_thw)

    def _parse_and_validate_multimodal_inputs(self,
                                              image_grid_thw: tuple[tuple[int,
                                                                          int,
                                                                          int],
                                                                    ...],
                                              **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values", "image_embeds"
                             ) and "image" not in mm_input_by_modality:
                mm_input_by_modality[
                    "image"] = self._parse_and_validate_image_input(
                        image_grid_thw, **kwargs)
            # if input_key in ("pixel_values_videos", "video_embeds"
            #                  ) and "video" not in mm_input_by_modality:
            #     mm_input_by_modality[
            #         "video"] = self._parse_and_validate_video_input(**kwargs)
        return mm_input_by_modality

    @partial(
        jax.jit,
        static_argnames=("image_grid_thw", ),
    )
    def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
        return self.visual(image_pixel_values, (image_grid_thw, ))

    def _process_image_input(
            self, image_input: Qwen2_5_VLImageInputs) -> tuple[jax.Array, ...]:

        grid_thw = image_input["image_grid_thw"]

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].astype(
                self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"]
            image_embeds = []
            current_idx = 0
            for image_thw in grid_thw:
                t, h, w = image_thw
                image_size = t * h * w
                end_idx = current_idx + image_size
                image_pixel_values = pixel_values[current_idx:end_idx, :]
                image_embeds.append(
                    self.get_single_image_embedding(image_pixel_values,
                                                    image_thw))
                current_idx = end_idx
            image_embeds = jnp.concatenate(image_embeds, axis=0)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.config.spatial_merge_size
        sizes = np.prod(np.array(grid_thw, dtype=np.int64),
                        axis=-1) // merge_size // merge_size

        if sizes.size == 0:
            return ()
        if sizes.size == 1:
            return (image_embeds, )

        split_indices = np.cumsum(sizes)[:-1]
        return tuple(jnp.split(image_embeds, split_indices))

    def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
                                                                    int], ...],
                                  **kwargs: object) -> MultiModalEmbeddings:

        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            image_grid_thw, **kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor correspoending to a multimodal data item (image or video).
        multimodal_embeddings: tuple[jax.Array, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            # if modality == "video":
            #     video_embeddings = self._process_video_input(multimodal_input)
            #     multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
        self, input_ids: jax.Array,
        multimodal_embeddings: Optional[MultiModalEmbeddings]
    ) -> jax.Array:

        inputs_embeds = self.language_model.model.embed(input_ids)

        if multimodal_embeddings is not None \
            and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])

        return inputs_embeds

    def __call__(
        self,
        kv_caches: list[jax.Array],
        input_ids: Optional[jax.Array],
        attention_metadata: AttentionMetadata,
        inputs_embeds: Optional[jax.Array] = None,
        *args,
    ) -> tuple[list[jax.Array], jax.Array, List[jax.Array]]:
        # The logic of choosing between input_ids and inputs_embeds is
        # handled inside self.language_model.__call__
        kv_caches, x, [] = self.language_model(
            kv_caches=kv_caches,
            input_ids=input_ids,
            attention_metadata=attention_metadata,
            inputs_embeds=inputs_embeds,
        )
        return kv_caches, x, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, rng_key: jax.Array) -> None:
        self.rng = nnx.Rngs(rng_key)
        self.language_model.rng = self.rng

        # Key: path to a HF layer weight
        # Value: a tuple of (path to a nnx layer weight, nnx weight sharding)

        mappings = {
            "model.embed_tokens": "language_model.model.embed.embedding",
            "model.layers.*.input_layernorm":
            "language_model.model.layers.*.input_layernorm.scale",
            "model.layers.*.mlp.down_proj":
            "language_model.model.layers.*.mlp.down_proj.kernel",
            "model.layers.*.mlp.gate_proj":
            "language_model.model.layers.*.mlp.gate_proj.kernel",
            "model.layers.*.mlp.up_proj":
            "language_model.model.layers.*.mlp.up_proj.kernel",
            "model.layers.*.post_attention_layernorm":
            "language_model.model.layers.*.post_attention_layernorm.scale",
            "model.layers.*.self_attn.k_proj":
            "language_model.model.layers.*.self_attn.k_proj.kernel",
            "model.layers.*.self_attn.o_proj":
            "language_model.model.layers.*.self_attn.o_proj.kernel",
            "model.layers.*.self_attn.q_proj":
            "language_model.model.layers.*.self_attn.q_proj.kernel",
            "model.layers.*.self_attn.v_proj":
            "language_model.model.layers.*.self_attn.v_proj.kernel",
            "model.layers.*.self_attn.q_proj.bias":
            "language_model.model.layers.*.self_attn.q_proj.bias",
            "model.layers.*.self_attn.k_proj.bias":
            "language_model.model.layers.*.self_attn.k_proj.bias",
            "model.layers.*.self_attn.v_proj.bias":
            "language_model.model.layers.*.self_attn.v_proj.bias",
            "model.norm": "language_model.model.norm.scale",
            "visual.blocks.*.attn.proj.bias": "visual.blocks.*.attn.proj.bias",
            "visual.blocks.*.attn.proj": "visual.blocks.*.attn.proj.kernel",
            "visual.blocks.*.attn.qkv.bias":
            "visual.blocks.*.attn.qkv_proj.bias",
            "visual.blocks.*.attn.qkv": "visual.blocks.*.attn.qkv_proj.kernel",
            "visual.blocks.*.mlp.down_proj.bias":
            "visual.blocks.*.mlp.down_proj.bias",
            "visual.blocks.*.mlp.down_proj":
            "visual.blocks.*.mlp.down_proj.kernel",
            "visual.blocks.*.mlp.gate_proj.bias":
            "visual.blocks.*.mlp.gate_proj.bias",
            "visual.blocks.*.mlp.gate_proj":
            "visual.blocks.*.mlp.gate_proj.kernel",
            "visual.blocks.*.mlp.up_proj.bias":
            "visual.blocks.*.mlp.up_proj.bias",
            "visual.blocks.*.mlp.up_proj":
            "visual.blocks.*.mlp.up_proj.kernel",
            "visual.blocks.*.norm1": "visual.blocks.*.norm1.scale",
            "visual.blocks.*.norm2": "visual.blocks.*.norm2.scale",
            "visual.merger.ln_q": "visual.merger.ln_q.scale",
            "visual.merger.mlp.0.bias": "visual.merger.mlp_fc1.bias",
            "visual.merger.mlp.0": "visual.merger.mlp_fc1.kernel",
            "visual.merger.mlp.2.bias": "visual.merger.mlp_fc2.bias",
            "visual.merger.mlp.2": "visual.merger.mlp_fc2.kernel",
            "visual.patch_embed.proj": "visual.patch_embed.proj.kernel",
        }

        # Add lm_head mapping only if it's not tied to embeddings
        hf_config = self.vllm_config.model_config.hf_config
        if not hf_config.tie_word_embeddings:
            mappings.update({
                "lm_head": "language_model.model.lm_head",
            })

        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
        load_hf_weights(vllm_config=self.vllm_config,
                        model=self,
                        metadata_map=metadata_map,
                        mesh=self.mesh)