nexaai-1.0.19rc6-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +83 -7
- nexaai/utils/model_types.py +2 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +35 -24
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1309 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import math
+import numpy as np
+
+import os
+import sys
+
+curr_dir = os.path.dirname(os.path.abspath(__file__))
+llm_common_dir = os.path.join(curr_dir, "llm_common")
+sys.path.append(llm_common_dir)
+
+# Try relative imports first, fallback to sys.path approach for Nuitka compatibility
+try:
+    from .llm_common.base import (
+        BaseModelArgs,
+        create_attention_mask,
+        scaled_dot_product_attention,
+    )
+    from .llm_common.rope_utils import initialize_rope
+except ImportError:
+    # Fallback for Nuitka compiled environment
+    from llm_common.base import (
+        BaseModelArgs,
+        create_attention_mask,
+        scaled_dot_product_attention,
+    )
+    from llm_common.rope_utils import initialize_rope
+from switch_layers import SwitchGLU
+
+
+@dataclass
+class VisionConfig:
+    hidden_size: int = 1152
+    intermediate_size: int = 4304
+    num_heads: int = 16
+    num_hidden_layers: int = 27
+    patch_size: int = 16
+    temporal_patch_size: int = 2
+    in_channels: int = 3
+    hidden_act: str = "gelu_pytorch_tanh"
+    spatial_merge_size: int = 2
+    out_hidden_size: int = 2048
+    num_position_embeddings: int = 2304
+    deepstack_visual_indexes: List[int] = None
+
+    def __post_init__(self):
+        if self.deepstack_visual_indexes is None:
+            self.deepstack_visual_indexes = [8, 16, 24]
+
+
+@dataclass
+class TextConfig(BaseModelArgs):
+    model_type: str = "qwen3_vl_moe_text"
+    hidden_size: int = 2048
+    num_hidden_layers: int = 48
+    intermediate_size: int = 6144
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 4
+    rms_norm_eps: float = 1e-6
+    vocab_size: int = 152064
+    max_position_embeddings: int = 128000
+    rope_theta: float = 1000000.0
+    head_dim: int = 128
+    tie_word_embeddings: bool = False
+    attention_bias: bool = False
+    attention_dropout: float = 0.0
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    # MoE specific parameters
+    num_experts: int = 128
+    num_experts_per_tok: int = 8
+    moe_intermediate_size: int = 768
+    shared_expert_intermediate_size: int = 0
+    norm_topk_prob: bool = True
+    decoder_sparse_step: int = 1
+    max_window_layers: int = 48
+    sliding_window: int = 32768
+    mlp_only_layers: List[int] = None
+    use_qk_norm: bool = True
+    layer_types: List[str] = None
+
+    def __post_init__(self):
+        if self.rope_scaling is None:
+            self.rope_scaling = {
+                "mrope_interleaved": True,
+                "mrope_section": [24, 20, 20],
+                "rope_type": "default"
+            }
+        if self.mlp_only_layers is None:
+            self.mlp_only_layers = []
+        if self.layer_types is None:
+            # This would need to be populated based on the actual model architecture
+            self.layer_types = []
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    vision_config: VisionConfig = None
+    text_config: TextConfig = None
+    image_token_id: int = 151655
+    vision_start_token_id: int = 151652
+    vision_end_token_id: int = 151653
+
+    def __post_init__(self):
+        if self.vision_config is None:
+            self.vision_config = VisionConfig()
+        if self.text_config is None:
+            self.text_config = TextConfig()
+
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(q, k, cos, sin):
+    cos = mx.expand_dims(cos, axis=-2)
+    sin = mx.expand_dims(sin, axis=-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+    sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class VisionMLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+
+    def __call__(self, hidden_state):
+        return self.linear_fc2(nn.gelu(self.linear_fc1(hidden_state)))
+
+
+class VisionPatchEmbed(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+        self.proj = nn.Conv3d(
+            self.in_channels,
+            self.embed_dim,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        target_dtype = self.proj.weight.dtype
+
+        # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
+        # This matches the PyTorch ground truth exactly
+        hidden_states = hidden_states.reshape(
+            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+        )
+
+        # Convert to MLX format: [batch, temporal, height, width, channels]
+        hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
+
+        # Apply conv3d with target dtype and reshape to match PyTorch output
+        hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
+
+        return hidden_states
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0):
+        super().__init__()
+        # Don't store inv_freq as a parameter since it causes loading issues
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        # Compute inv_freq on the fly
+        inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
+        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class VisionPatchMerger(nn.Module):
+    def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
+        super().__init__()
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size ** 2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+
+        norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
+        self.ln_q = nn.LayerNorm(norm_size, eps=1e-6)
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if self.use_postshuffle_norm:
+            x = self.ln_q(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
+        else:
+            x = self.ln_q(x).reshape(-1, self.hidden_size)
+
+        x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
+        return x
+
+
+class VisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.dim = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = self.dim // self.num_heads
+        self.scaling = self.head_dim ** -0.5
+
+        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+        self.proj = nn.Linear(self.dim, self.dim)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        cu_seqlens: mx.array,
+        rotary_pos_emb: Optional[mx.array] = None,
+        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
+        **kwargs,
+    ) -> mx.array:
+        seq_length = hidden_states.shape[0]
+        qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
+        qkv = qkv.transpose(1, 0, 2, 3)
+        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb_vision(
+            query_states, key_states, cos, sin
+        )
+
+        query_states = query_states.transpose(1, 0, 2)
+        key_states = key_states.transpose(1, 0, 2)
+        value_states = value_states.transpose(1, 0, 2)
+
+        query_states = mx.expand_dims(query_states, axis=0)
+        key_states = mx.expand_dims(key_states, axis=0)
+        value_states = mx.expand_dims(value_states, axis=0)
+
+        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+        split_indices = []
+        cumsum = 0
+        for length in lengths[:-1]:
+            cumsum += int(length)
+            split_indices.append(cumsum)
+
+        if split_indices:
+            q_splits = mx.split(query_states, split_indices, axis=1)
+            k_splits = mx.split(key_states, split_indices, axis=1)
+            v_splits = mx.split(value_states, split_indices, axis=1)
+        else:
+            q_splits = [query_states]
+            k_splits = [key_states]
+            v_splits = [value_states]
+
+        attn_outputs = []
+        for q, k, v in zip(q_splits, k_splits, v_splits):
+            attn_out = scaled_dot_product_attention(
+                q, k, v,
+                scale=self.scaling, mask=None, cache=None
+            )
+            attn_outputs.append(attn_out)
+
+        attn_output = mx.concatenate(attn_outputs, axis=1)
+
+        attn_output = attn_output[0].transpose(1, 0, 2)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+
+        return attn_output
+
+
+class VisionBlock(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.attn = VisionAttention(config)
+        self.mlp = VisionMLP(config)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        cu_seqlens: mx.array,
+        position_embeddings: Tuple[mx.array, mx.array],
+    ) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.spatial_merge_size = config.spatial_merge_size
+        self.patch_size = config.patch_size
+
+        self.patch_embed = VisionPatchEmbed(config)
+        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
+        self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
+
+        head_dim = config.hidden_size // config.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = [VisionBlock(config) for _ in range(config.num_hidden_layers)]
+        self.merger = VisionPatchMerger(config, use_postshuffle_norm=False)
+
+        self.deepstack_visual_indexes = config.deepstack_visual_indexes
+        self.deepstack_merger_list = [
+            VisionPatchMerger(config, use_postshuffle_norm=True)
+            for _ in range(len(config.deepstack_visual_indexes))
+        ]
+
+    def rot_pos_emb(self, grid_thw: mx.array) -> mx.array:
+        merge_size = self.spatial_merge_size
+
+        max_hw = int(grid_thw[:, 1:].max().item())
+        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
+
+        pos_ids_parts = []
+
+        for i in range(grid_thw.shape[0]):
+            num_frames = int(grid_thw[i, 0].item())
+            height = int(grid_thw[i, 1].item())
+            width = int(grid_thw[i, 2].item())
+
+            merged_h, merged_w = height // merge_size, width // merge_size
+
+            block_rows = mx.arange(merged_h)  # block row indices
+            block_cols = mx.arange(merged_w)  # block col indices
+            intra_row = mx.arange(merge_size)  # intra-block row offsets
+            intra_col = mx.arange(merge_size)  # intra-block col offsets
+
+            # Compute full-resolution positions using broadcasting
+            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+            row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+            col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+
+            coords = mx.stack([row_idx, col_idx], axis=-1)
+
+            if num_frames > 1:
+                coords = mx.tile(coords, (num_frames, 1))
+
+            pos_ids_parts.append(coords)
+
+        # Concatenate all coordinate parts
+        pos_ids = mx.concatenate(pos_ids_parts, axis=0)
+
+        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
+        embeddings = embeddings.reshape(embeddings.shape[0], -1)
+        return embeddings
+
+    def fast_pos_embed_interpolate(self, grid_thw: mx.array):
+        patch_pos_embeds = []
+
+        for i in range(grid_thw.shape[0]):
+            t = int(grid_thw[i, 0].item())
+            h = int(grid_thw[i, 1].item())
+            w = int(grid_thw[i, 2].item())
+
+            # Simple position embedding interpolation
+            h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
+            w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
+
+            h_idxs_floor = mx.floor(h_idxs).astype(mx.int32)
+            w_idxs_floor = mx.floor(w_idxs).astype(mx.int32)
+            h_idxs_ceil = mx.minimum(h_idxs_floor + 1, self.num_grid_per_side - 1)
+            w_idxs_ceil = mx.minimum(w_idxs_floor + 1, self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor.astype(mx.float32)
+            dw = w_idxs - w_idxs_floor.astype(mx.float32)
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            # Compute bilinear interpolation indices and weights
+            indices_tl = (base_h[:, None] + w_idxs_floor[None, :]).reshape(-1)
+            indices_tr = (base_h[:, None] + w_idxs_ceil[None, :]).reshape(-1)
+            indices_bl = (base_h_ceil[:, None] + w_idxs_floor[None, :]).reshape(-1)
+            indices_br = (base_h_ceil[:, None] + w_idxs_ceil[None, :]).reshape(-1)
+
+            weights_tl = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            weights_tr = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            weights_bl = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            weights_br = (dh[:, None] * dw[None, :]).reshape(-1)
+
+            # Get embeddings and interpolate
+            pos_embed_tl = self.pos_embed(indices_tl) * weights_tl[:, None]
+            pos_embed_tr = self.pos_embed(indices_tr) * weights_tr[:, None]
+            pos_embed_bl = self.pos_embed(indices_bl) * weights_bl[:, None]
+            pos_embed_br = self.pos_embed(indices_br) * weights_br[:, None]
+
+            pos_embed = pos_embed_tl + pos_embed_tr + pos_embed_bl + pos_embed_br
+
+            # Repeat for temporal dimension and apply spatial merging
+            pos_embed = mx.tile(pos_embed, (t, 1))
+
+            # Apply spatial merging pattern
+            merge_size = self.config.spatial_merge_size
+            pos_embed = pos_embed.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+            pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
+            pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
+
+            patch_pos_embeds.append(pos_embed)
+
+        return mx.concatenate(patch_pos_embeds, axis=0)
+
+    def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> Tuple[mx.array, List[mx.array]]:
+        hidden_states = self.patch_embed(hidden_states)
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        hidden_states = hidden_states + pos_embeds
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        seq_len = hidden_states.shape[0]
+
+        emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
+        position_embeddings = (mx.cos(emb), mx.sin(emb))
+
+        # Create cumulative sequence lengths (following HuggingFace implementation)
+        # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
+        seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2]  # h * w for each image
+        seq_lens = []
+        for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
+            seq_lens.extend([seq_len] * int(repeats))
+        seq_lens = mx.array(seq_lens)
+
+        # Then compute cumulative sum
+        cu_seqlens = mx.cumsum(seq_lens)
+        # Pad with 0 at the beginning
+        cu_seqlens = mx.concatenate([mx.array([0]), cu_seqlens])
+
+        deepstack_feature_lists = []
+        for layer_num, blk in enumerate(self.blocks):
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                position_embeddings=position_embeddings,
+            )
+            if layer_num in self.deepstack_visual_indexes:
+                idx = self.deepstack_visual_indexes.index(layer_num)
+                deepstack_feature = self.deepstack_merger_list[idx](hidden_states)
+                deepstack_feature_lists.append(deepstack_feature)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states, deepstack_feature_lists
+
+
+class TextRotaryEmbedding(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        # MRoPE configuration
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", "default")
+            self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+        else:
+            self.rope_type = "default"
+            self.mrope_section = [24, 20, 20]
+
+        # Store parameters for computing inv_freq on the fly
+        self.head_dim = config.head_dim
+        self.theta = config.rope_theta
+
+        # Attention scaling (simplified - may need adjustment based on actual config)
+        self.attention_scaling = 1.0
+
+    def _get_inv_freq(self):
+        """Compute inverse frequencies on the fly"""
+        inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim))
+        # Expand for 3 dimensions (T, H, W)
+        return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THTHWHTHW...TT], preserving frequency continuity.
+        args:
+            x: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            x_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
+    def __call__(self, x: mx.array, position_ids: mx.array) -> mx.array:
+        """
+        Args:
+            x: Input tensor for dtype reference
+            position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
+
+        Returns:
+            cos, sin: Cosine and sine embeddings
+        """
+        # Handle 2D position_ids by expanding to 3D for MRoPE
+        if position_ids.ndim == 2:
+            position_ids = mx.broadcast_to(position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1]))
+
+        batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
+
+        # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
+        inv_freq_expanded = mx.broadcast_to(
+            self._get_inv_freq()[:, None, None, :],
+            (3, batch_size, 1, self._get_inv_freq().shape[-1])
+        )
+
+        # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
+        position_ids_expanded = position_ids[..., None].astype(mx.float32)
+
+        # Compute frequencies: (3, batch_size, seq_len, dim//2)
+        freqs = inv_freq_expanded * position_ids_expanded
+
+        # Apply interleaved MRoPE
+        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+
+        # Create embeddings
+        emb = mx.concatenate([freqs, freqs], axis=-1)  # (batch_size, seq_len, head_dim)
+        cos = mx.cos(emb) * self.attention_scaling
+        sin = mx.sin(emb) * self.attention_scaling
+
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+class TextAttention(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+
+        dim = config.hidden_size
+        self.n_heads = config.num_attention_heads
+        self.n_kv_heads = config.num_key_value_heads
+        self.head_dim = config.head_dim
+        self.scale = self.head_dim ** -0.5
+
+        self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=config.attention_bias)
+
+        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        # Initialize rope directly
+        self.rope = initialize_rope(
+            config.head_dim,
+            base=config.rope_theta,
+            traditional=False,
+            scaling_config=config.rope_scaling,
+            max_position_embeddings=config.max_position_embeddings,
+        )
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        cos: Optional[mx.array] = None,
+        sin: Optional[mx.array] = None,
+        rope_deltas: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, Optional[mx.array]]:
+        B, L, D = hidden_states.shape
+
+        queries = self.q_proj(hidden_states).reshape(B, L, self.n_heads, -1)
+        keys = self.k_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
+        values = self.v_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
+
+        queries = self.q_norm(queries).transpose(0, 2, 1, 3)
+        keys = self.k_norm(keys).transpose(0, 2, 1, 3)
+        values = values.transpose(0, 2, 1, 3)
+
+        # Apply rope directly to queries and keys
+        if cos is not None and sin is not None:
+            queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
+            if cache is not None:
+                keys, values = cache.update_and_fetch(keys, values)
+        else:
+            if cache is not None:
+                queries = self.rope(queries, offset=cache.offset+rope_deltas)
+                keys = self.rope(keys, offset=cache.offset+rope_deltas)
+                keys, values = cache.update_and_fetch(keys, values)
+            else:
+                queries = self.rope(queries)
+                keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=attention_mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output), None
+
+
+class TextMLP(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+    def __call__(self, x):
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+# Add this custom MoE implementation to replace SwitchGLU usage
+
+class TextMoEExperts(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        # Use the optimized SwitchGLU implementation for efficient expert computation
+        self.switch_glu = SwitchGLU(
+            input_dims=config.hidden_size,
+            hidden_dims=config.moe_intermediate_size,
+            num_experts=config.num_experts,
+            activation=nn.SiLU(),
+            bias=False
+        )
+
+    def __call__(self, hidden_states: mx.array, routing_weights: mx.array, router_indices: mx.array) -> mx.array:
+        # Use the efficient SwitchGLU implementation
+        # SwitchGLU handles the expert routing internally and is highly optimized
+        expert_output = self.switch_glu(hidden_states, router_indices)
+
+        # Apply routing weights and sum over experts (top_k dimension)
+        weighted_output = expert_output * mx.expand_dims(routing_weights, -1)
+        final_output = mx.sum(weighted_output, axis=-2)
+
+        return final_output
+
+class TextSparseMoeBlock(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+
+        self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+        self.experts = TextMoEExperts(config)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        batch_size, sequence_length, hidden_dim = x.shape
+        x_flat = x.reshape(-1, hidden_dim)
+
+        # Router computation
+        router_logits = self.gate(x_flat)
+        routing_weights = mx.softmax(router_logits, axis=-1, precise=True)
+
+        # Top-k selection
+        router_indices = mx.argpartition(-routing_weights, kth=self.top_k - 1, axis=-1)[..., :self.top_k]
+        routing_weights = mx.take_along_axis(routing_weights, router_indices, axis=-1)
+
+        if self.norm_topk_prob:
+            routing_weights = routing_weights / mx.sum(routing_weights, axis=-1, keepdims=True)
+
+        # Expert computation
+        final_hidden_states = self.experts(x, routing_weights, router_indices)
+
+        return final_hidden_states
+
+
+class TextDecoderLayer(nn.Module):
+    def __init__(self, config: TextConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = TextAttention(config, layer_idx)
+
+        # Determine if this layer should use MoE
+        use_moe = (
+            layer_idx not in config.mlp_only_layers and
+            config.num_experts > 0 and
+            (layer_idx + 1) % config.decoder_sparse_step == 0
+        )
+
+        if use_moe:
+            self.mlp = TextSparseMoeBlock(config)
+        else:
+            self.mlp = TextMLP(config)
+
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+        cos: Optional[mx.array] = None,
+        sin: Optional[mx.array] = None,
+        rope_deltas: Optional[mx.array] = None,
+    ) -> mx.array:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            cache=cache,
+            cos=cos,
+            sin=sin,
+            rope_deltas=rope_deltas,
+        )
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class TextModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            TextDecoderLayer(config, layer_idx)
+            for layer_idx in range(config.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = TextRotaryEmbedding(config)
+
+    def _deepstack_process(
+        self,
+        hidden_states: mx.array,
+        visual_pos_masks: mx.array,
+        deepstack_visual_embeds: mx.array,
+    ) -> mx.array:
+        if visual_pos_masks is None or deepstack_visual_embeds is None:
+            return hidden_states
+        B, L, D = hidden_states.shape
+        mask_flat = visual_pos_masks.astype(mx.int32).reshape(-1)
+        idx_flat = mx.cumsum(mask_flat, axis=0) - 1
+        N = deepstack_visual_embeds.shape[0]
+        idx_flat = mx.maximum(idx_flat, 0)
+        eq = (idx_flat[:, None] == mx.arange(N)[None, :]).astype(hidden_states.dtype)
+        add_flat = eq @ deepstack_visual_embeds.astype(hidden_states.dtype)
+        add_flat = add_flat * mask_flat[:, None].astype(hidden_states.dtype)
+        add = add_flat.reshape(B, L, D)
+        return hidden_states + add
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        inputs_embeds: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+        cache=None,
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[List[mx.array]] = None,
+        cos: Optional[mx.array] = None,
+        sin: Optional[mx.array] = None,
+        rope_deltas: Optional[mx.array] = None,
+    ):
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+
+        if attention_mask is None:
+            attention_mask = create_attention_mask(hidden_states, cache, return_array=True)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer_idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                cache=c,
+                cos=cos,
+                sin=sin,
+                rope_deltas=rope_deltas,
+            )
+            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
+                hidden_states = self._deepstack_process(hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx])
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+# Standalone Vision Model
+class VEGModel(nn.Module):
+    def __init__(self, vision_config: VisionConfig):
+        super().__init__()
+        self.config = vision_config
+        self.visual = VisionModel(vision_config)
+
+    def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
+        return self.visual(pixel_values, image_grid_thw)
+
+    def sanitize(self, weights):
+        sanitized = {}
+        for k, v in weights.items():
+            if 'visual.' in k:
+                # Remove prefixes to match our model structure
+                clean_key = k.replace('model.visual.', '').replace('visual.', '')
+                sanitized[f'visual.{clean_key}'] = v
+        return sanitized
+
+
+# Pure LLM Model (no vision components)
+class LLMModel(nn.Module):
+    def __init__(self, text_config: TextConfig):
+        super().__init__()
+        self.args = text_config
+        self.config = text_config
+        self.language_model = TextModel(text_config)
+        if not text_config.tie_word_embeddings:
+            self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
+
+    def get_rope_index(
+        self,
+        input_ids: Optional[mx.array] = None,
+        image_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, mx.array]:
+        """Simplified version for images only (no video support)."""
+
+        spatial_merge_size = 2
+        image_token_id = 151655
+        vision_start_token_id = 151652
+        mrope_position_deltas = []
+
+        if input_ids is not None and image_grid_thw is not None:
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(total_input_ids)
+
+            batch_size, seq_len = input_ids.shape
+            position_ids_list = []
+            image_index = 0
+
+            for i in range(batch_size):
+                input_ids_seq = total_input_ids[i]
+                mask_seq = attention_mask[i]
+
+                # Use mask to get valid length
+                valid_length = int(mx.sum(mask_seq).item())
+                input_ids_seq = input_ids_seq[:valid_length]
+
+                image_nums = 0
+                # Find vision start tokens by iterating through the sequence
+                vision_start_positions = []
+                for pos in range(input_ids_seq.shape[0]):
+                    if input_ids_seq[pos].item() == vision_start_token_id:
+                        vision_start_positions.append(pos)
+
+                if len(vision_start_positions) > 0:
+                    for pos in vision_start_positions:
+                        if pos + 1 < input_ids_seq.shape[0]:
+                            if input_ids_seq[pos + 1].item() == image_token_id:
+                                image_nums += 1
+
+                input_tokens = input_ids_seq.tolist()
+                llm_pos_ids_list = []
+                st = 0
+                remain_images = image_nums
+
+                for _ in range(image_nums):
+                    ed_image = input_tokens.index(image_token_id, st)
+
+                    t = image_grid_thw[image_index, 0].item()
+                    h = image_grid_thw[image_index, 1].item()
+                    w = image_grid_thw[image_index, 2].item()
+                    image_index += 1
+                    remain_images -= 1
+                    ed = ed_image
+
+                    llm_grid_t = int(t)
+                    llm_grid_h = int(h) // spatial_merge_size
+                    llm_grid_w = int(w) // spatial_merge_size
+                    text_len = ed - st
+
+                    st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                    # t_index is always 0 because llm_grid_t is always 1 for images
+                    t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+                    t_index = mx.broadcast_to(t_index, (llm_grid_t, llm_grid_h * llm_grid_w)).reshape(-1)
+
+                    h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+                    h_index = mx.broadcast_to(h_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
+
+                    w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+                    w_index = mx.broadcast_to(w_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
+
+                    vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    llm_pos_ids_list.append(vision_pos)
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    text_len = len(input_tokens) - st
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+                # Create position_ids for this batch item, pad to seq_len
+                batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+                valid_length = min(seq_len, llm_positions.shape[1])
+
+                # Create new arrays for each dimension
+                pos_dim0 = mx.concatenate([llm_positions[0, :valid_length],
+                                           mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
+                pos_dim1 = mx.concatenate([llm_positions[1, :valid_length],
+                                           mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
+                pos_dim2 = mx.concatenate([llm_positions[2, :valid_length],
+                                           mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
+
+                batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+                position_ids_list.append(batch_position_ids)
+
+                mrope_position_deltas.append(llm_positions.max().item() + 1 - len(total_input_ids[i]))
+
+            # Stack all batch position_ids
+            position_ids = mx.stack(position_ids_list, axis=1)  # Shape: (3, batch_size, seq_len)
+            mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+                position_ids = mx.where(attention_mask == 0, 1, position_ids)
+                position_ids = mx.expand_dims(position_ids, axis=0)
+                position_ids = mx.broadcast_to(position_ids, (3, position_ids.shape[1], position_ids.shape[2]))
+                max_position_ids = mx.max(mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True)
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                seq_len = input_ids.shape[1]
+                batch_size = input_ids.shape[0]
+                position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+                position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+                mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+            return position_ids, mrope_position_deltas
+
+    def __call__(
+        self,
+        inputs: mx.array = None,
+        mask: mx.array = None,
+        cache=None,
+        inputs_embeds: Optional[mx.array] = None,
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[List[mx.array]] = None,
+        cos: Optional[mx.array] = None,
+        sin: Optional[mx.array] = None,
+        rope_deltas: Optional[mx.array] = None,
+    ):
+        out = self.language_model(
+            input_ids=inputs,
+            inputs_embeds=inputs_embeds,
+            attention_mask=mask,
+            cache=cache,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+            cos=cos,
+            sin=sin,
+            rope_deltas=rope_deltas,
+        )
+        if self.args.tie_word_embeddings:
+            return self.language_model.embed_tokens.as_linear(out)
+        else:
+            return self.lm_head(out)
+
+    def sanitize(self, weights):
+        sanitized = {}
+        for k, v in weights.items():
+            if not ('visual.' in k):
+                # Handle key mapping from combined model to LLM-only model
+                clean_key = k
+
+                # Remove model. prefix if present
+                if clean_key.startswith('model.'):
+                    clean_key = clean_key[6:]  # Remove 'model.'
+
+                # Map language_ prefixed keys to language_model structure
+                if clean_key.startswith('language_'):
+                    if clean_key.startswith('language_layers.'):
+                        clean_key = 'language_model.layers.' + clean_key[16:]  # Map to language_model.layers.
+                    elif clean_key.startswith('language_embed_tokens.'):
+                        clean_key = 'language_model.embed_tokens.' + clean_key[22:]  # Map to language_model.embed_tokens.
+                    elif clean_key.startswith('language_norm.'):
+                        clean_key = 'language_model.norm.' + clean_key[14:]  # Map to language_model.norm.
+
+                sanitized[clean_key] = v
+
+        # Handle tied embeddings - remove lm_head if using tied embeddings
+        if self.args.tie_word_embeddings:
+            sanitized.pop("lm_head.weight", None)
+
+        return sanitized
+
+    @property
+    def layers(self):
+        return self.language_model.layers
+
+
+# Combined Model (for compatibility and utility functions)
+class Qwen3VLModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.config = args
+        self.visual = VisionModel(args.vision_config)
+        self.language_model = TextModel(args.text_config)
+
+    def sanitize(self, weights):
+        # Map weights to match the combined model structure
+        sanitized = {}
+        for k, v in weights.items():
+            # Remove 'model.' prefix if present to match our structure
+            clean_key = k.replace('model.', '') if k.startswith('model.') else k
+            sanitized[clean_key] = v
+        return sanitized
+
+    def get_image_features(
+        self,
+        pixel_values: mx.array,
+        image_grid_thw: Optional[mx.array] = None
+    ):
+        image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
+        # Split based on grid dimensions
+        if image_grid_thw is not None:
+            split_sizes = (mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size ** 2)).tolist()
+            # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
+            split_indices = []
+            cumsum = 0
+            for size in split_sizes[:-1]:  # Exclude last element
+                cumsum += size
+                split_indices.append(cumsum)
+
+            if split_indices:  # Only split if we have indices
+                image_embeds = mx.split(image_embeds, split_indices)
+            else:
+                image_embeds = [image_embeds]  # Single image case
+        return image_embeds, deepstack_visual_embeds
+
+
+    def __call__(
+        self,
+        input_ids: mx.array = None,
+        attention_mask: Optional[mx.array] = None,
+        inputs_embeds: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        image_grid_thw: Optional[mx.array] = None,
+        cache=None,
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[List[mx.array]] = None,
+        cos: Optional[mx.array] = None,
+        sin: Optional[mx.array] = None,
+        rope_deltas: Optional[mx.array] = None,
+    ):
+        if inputs_embeds is None:
+            inputs_embeds = self.language_model.embed_tokens(input_ids)
+
+        # Process images
+
+        if pixel_values is not None:
+            image_embeds, deepstack_visual_embeds = self.get_image_features(
+                pixel_values, image_grid_thw
+            )
+
+            # Create masks and embed visual features
+            if isinstance(image_embeds, list):
+                image_embeds = mx.concatenate(image_embeds, axis=0)
+
+            # Find image token positions and replace with visual embeddings
+            image_mask = (input_ids == self.args.image_token_id)
+            visual_pos_masks = image_mask
+
+            # Replace image tokens with visual embeddings
+            inputs_embeds = inputs_embeds.at[image_mask].set(
+                image_embeds.astype(inputs_embeds.dtype)
+            )
+
+
+        outputs = self.language_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache=cache,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+            cos=cos,
+            sin=sin,
+            rope_deltas=rope_deltas,
+        )
+
+        return outputs
+
+
+def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
+    """
+    Handle the processing of multimodal embeddings including image features and position encoding.
+
+    This function processes vision and text inputs to create unified embeddings that can be fed
+    into the language model. It handles:
+    - Vision feature extraction from pixel values
+    - Deepstack visual embedding collection
+    - Image token replacement in text embeddings
+    - Position encoding setup for MRoPE (Multi-dimensional RoPE)
+
+    Args:
+        vision_model: The vision encoder model (VEGModel instance)
+        llm_model: The language model (LLMModel instance)
+        input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
+        pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
+        image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
+
+    Returns:
+        tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
+            - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
+            - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
+            - visual_pos_masks: Boolean mask indicating image token positions
+            - cos: Cosine values for rotary position encoding
+            - sin: Sine values for rotary position encoding
+            - rope_deltas: Position offset deltas for rope computation
+    """
+    inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
+    deepstack_visual_embeds = None
+    visual_pos_masks = None
+    cos = None
+    sin = None
+    rope_deltas = 0
+
+    if pixel_values is not None:
+        if pixel_values.ndim == 4:
+            pixel_values = mx.expand_dims(pixel_values, axis=2)
+
+        # Process each image individually to prevent feature mixing
+        image_embeds_list = []
+        all_deepstack_embeds = []
+
+        # Calculate cumulative indices for each image
+        cumulative_patches = 0
+
+        for i in range(image_grid_thw.shape[0]):
+            # Calculate number of patches for current image
+            current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
+            start_idx = cumulative_patches
+            end_idx = cumulative_patches + current_patches
+            cumulative_patches += current_patches
+
+            single_pixel_values = pixel_values[start_idx:end_idx]
+            single_grid_thw = image_grid_thw[i:i+1]
+
+            # Use vision model directly
+            single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
+
+            # Split based on grid dimensions
+            if single_grid_thw is not None:
+                split_sizes = (mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size ** 2)).tolist()
+                split_indices = []
+                cumsum = 0
+                for size in split_sizes[:-1]:
+                    cumsum += size
+                    split_indices.append(cumsum)
+
+                if split_indices:
+                    single_embeds = mx.split(single_embeds, split_indices)
+                else:
+                    single_embeds = [single_embeds]
+
+            image_embeds_list.extend(single_embeds)
+
+            # Collect deepstack embeddings
+            if i == 0:
+                all_deepstack_embeds = single_deepstack
+            else:
+                # Concatenate deepstack embeddings from different images
+                for j in range(len(all_deepstack_embeds)):
+                    all_deepstack_embeds[j] = mx.concatenate([all_deepstack_embeds[j], single_deepstack[j]], axis=0)
+
+        deepstack_visual_embeds = all_deepstack_embeds
+
+        # Concatenate all image embeddings for processing
+        image_embeds = mx.concatenate(image_embeds_list, axis=0)
+
+        # Find all image token positions
+        image_token_id = 151655  # Default image token ID
+        image_mask = (input_ids.squeeze(0) == image_token_id)
+        image_mask_np = np.array(image_mask)
+        image_token_positions = np.where(image_mask_np)[0]
+
+        # Verify we have the correct number of image tokens
+        expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
+        assert len(image_token_positions) == expected_total_tokens, f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
+        # Replace image tokens with image embeddings
+        seq_len = inputs_embeds.shape[0]
+        result = inputs_embeds
+
+        # Replace image tokens with image embeddings sequentially
+        embed_idx = 0
+        for img_embed in image_embeds_list:
+            for patch_idx in range(img_embed.shape[0]):
+                token_pos = image_token_positions[embed_idx]
+                pos_mask = mx.arange(seq_len) == token_pos
+                result = mx.where(
+                    mx.expand_dims(pos_mask, axis=-1),
+                    mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
+                    result
+                )
+                embed_idx += 1
+
+        inputs_embeds = result
+        position_ids, rope_deltas = llm_model.get_rope_index(input_ids, image_grid_thw)
+        cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
+        if inputs_embeds.ndim == 2:
+            inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
+
+        if image_mask is not None:
+            visual_pos_masks = image_mask
+
+    return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
+
+
+# Legacy Model wrapper (for backward compatibility)
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model = Qwen3VLModel(args)
+        if not args.text_config.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.text_config.hidden_size, args.text_config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array = None,
+        mask: mx.array = None,
+        cache=None,
+        inputs_embeds: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        image_grid_thw: Optional[mx.array] = None,
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[List[mx.array]] = None,
+        cos: Optional[mx.array] = None,
+        sin: Optional[mx.array] = None,
+        rope_deltas: Optional[mx.array] = None,
+    ):
+        out = self.model(
+            input_ids=inputs,
+            inputs_embeds=inputs_embeds,
+            attention_mask=mask,
+            cache=cache,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+            cos=cos,
+            sin=sin,
+            rope_deltas=rope_deltas,
+        )
+        if self.args.text_config.tie_word_embeddings:
+            return self.model.language_model.embed_tokens.as_linear(out)
+        else:
+            return self.lm_head(out)
+
+    def sanitize(self, weights):
+        # Remove any unnecessary weights
+        sanitized = {}
+        for k, v in weights.items():
+            sanitized[k] = v
+
+        # Handle tied embeddings - remove lm_head if using tied embeddings
+        if self.args.text_config.tie_word_embeddings:
+            sanitized.pop("lm_head.weight", None)
+
+        return sanitized
+
+    @property
+    def layers(self):
+        return self.model.language_model.layers
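The added module wires its configuration through nested dataclasses: ModelArgs.__post_init__ fills in default VisionConfig and TextConfig instances, and TextConfig.__post_init__ supplies the interleaved-MRoPE rope_scaling defaults used by TextRotaryEmbedding. A minimal sketch of inspecting those defaults, assuming the wheel is installed and its llm_common helpers resolve; the dotted module path is inferred from the file list above and is not guaranteed to be importable:

    # Illustrative sketch only; not part of the diff above.
    from nexaai.mlx_backend.vlm.modeling.models.qwen3vl_moe.qwen3vl_moe import ModelArgs

    args = ModelArgs()  # __post_init__ builds default VisionConfig / TextConfig
    print(args.vision_config.deepstack_visual_indexes)      # [8, 16, 24]
    print(args.text_config.rope_scaling["mrope_section"])   # [24, 20, 20]
    print(args.text_config.num_experts, args.text_config.num_experts_per_tok)  # 128 8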