nexaai 1.0.14__cp310-cp310-macosx_14_0_universal2.whl → 1.0.15__cp310-cp310-macosx_14_0_universal2.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

The registry flags this version of nexaai as potentially problematic.

Binary file
nexaai/_version.py CHANGED
@@ -1,4 +1,4 @@
  # This file is generated by CMake from _version.py.in
  # Do not modify this file manually - it will be overwritten

- __version__ = "1.0.14"
+ __version__ = "1.0.15"
Binary file
Binary file
@@ -0,0 +1,261 @@
+ import argparse
+ import json
+ import sys
+ import os
+ import mlx.core as mx
+ import mlx.nn as nn
+ import time
+ from PIL import Image
+ import requests
+ import numpy as np
+ from pathlib import Path
+ from huggingface_hub import snapshot_download
+
+ # Add current directory to path for imports
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(curr_dir)
+ sys.path.append(os.path.dirname(curr_dir))
+
+ # Add the qwen3vl model directory to path
+ qwen3vl_dir = os.path.join(curr_dir, "modeling", "models", "qwen3_vl")
+ sys.path.append(qwen3vl_dir)
+
+ # Import required modules for quantized loading
+ from transformers import AutoTokenizer
+
+ # Try relative imports first, fallback to sys.path approach for Nuitka compatibility
+ try:
+     from .modeling.models.qwen3_vl.llm_common.generate import nexa_generate_step
+     from .modeling.models.qwen3_vl.llm_common.cache import make_prompt_cache
+     from .modeling.models.qwen3_vl.qwen3vl import (
+         VEGModel, LLMModel, ModelArgs, VisionConfig, TextConfig, handle_multimodal_embeds
+     )
+     from .modeling.models.qwen3_vl.processor import Qwen3VLProcessor
+ except ImportError:
+     # Fallback for Nuitka compiled environment - use sys.path approach
+     from llm_common.generate import nexa_generate_step
+     from llm_common.cache import make_prompt_cache
+     from qwen3vl import VEGModel, LLMModel, ModelArgs, VisionConfig, TextConfig, handle_multimodal_embeds
+     from processor import Qwen3VLProcessor
+
+ from ml import ChatMessage
+ from dataclasses import dataclass
+ from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
+ from .generate import GenerationResult
+
+ @dataclass
+ class Qwen3VLBundledModel:
+     """Container for Qwen3-VL vision and language models."""
+     vision_model: VEGModel
+     llm_model: LLMModel
+
+
+ def _ensure_list(x: Union[str, List[str], None]) -> Optional[List[str]]:
+     if x is None:
+         return None
+     return x if isinstance(x, list) else [x]
+
+
+ def load_qwen3_vl(
+     path_or_repo: str,
+     adapter_path: Optional[str] = None,
+     lazy: bool = False,
+     revision: Optional[str] = None,
+     **kwargs,
+ ) -> Tuple[Qwen3VLBundledModel, Qwen3VLProcessor]:
+     """Load Qwen3-VL quantized models and processor.
+
+     Parameters are aligned with .generate.load for compatibility.
+     """
+     model_path = Path(path_or_repo)
+     if not model_path.exists():
+         if "/" in path_or_repo:
+             model_path = Path(snapshot_download(
+                 repo_id=path_or_repo, repo_type="model", revision=revision))
+         else:
+             # Fallback to local modelfiles directory
+             model_path = Path(qwen3vl_dir) / "modelfiles"
+             if not model_path.exists():
+                 model_path = Path(curr_dir) / "modelfiles"
+
+     # Model configs (kept identical to main)
+     vision_config = VisionConfig(
+         hidden_size=1024,
+         intermediate_size=4096,
+         num_heads=16,
+         num_hidden_layers=24,
+         patch_size=16,
+         temporal_patch_size=2,
+         in_channels=3,
+         hidden_act="gelu",
+         spatial_merge_size=2,
+         out_hidden_size=2560,
+         num_position_embeddings=2304,
+         deepstack_visual_indexes=[5, 11, 17],
+     )
+
+     text_config = TextConfig(
+         model_type="qwen3vl",
+         hidden_size=2560,
+         num_hidden_layers=36,
+         intermediate_size=9728,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         rms_norm_eps=1e-6,
+         vocab_size=151936,
+         max_position_embeddings=32768,
+         rope_theta=5000000.0,
+         head_dim=128,
+         tie_word_embeddings=True,
+         attention_bias=False,
+         attention_dropout=0.0,
+         rope_scaling={"mrope_section": [24, 20, 20],
+                       "rope_type": "default", "type": "default"},
+     )
+
+     vision_model = VEGModel(vision_config)
+     llm_model = LLMModel(text_config)
+
+     # Try to load LLM model from available files in order of preference
+     preferred_order = [
+         ("qwen3vl-llm-4B-q4_0.safetensors", 4),
+         ("qwen3vl-llm-4B-q8_0.safetensors", 8),
+         ("qwen3vl-llm-4B-f32.safetensors", 32)
+     ]
+
+     llm_weights_path = None
+     quantization_bits = None
+
+     # Try loading in order of preference
+     for filename, bits in preferred_order:
+         candidate_path = model_path / filename
+         if candidate_path.exists():
+             llm_weights_path = candidate_path
+             quantization_bits = bits
+             break
+
+     if llm_weights_path is None:
+         # Fallback to original hardcoded path for backward compatibility
+         llm_weights_path = model_path / "qwen3vl-llm-4B-q4_0.safetensors"
+         quantization_bits = 4
+
+     vision_weights_path = model_path / "qwen3vl-vision-4B-f16.safetensors"
+
+     if not vision_weights_path.exists() or not llm_weights_path.exists():
+         raise FileNotFoundError(
+             f"Missing safetensors. Vision: {vision_weights_path}, LLM: {llm_weights_path}"
+         )
+
+     # Load weights (vision fp16, llm with detected quantization)
+     vision_model.set_dtype(mx.float16)
+     vision_model.load_weights(str(vision_weights_path), strict=True)
+
+     # Apply quantization if needed and load LLM weights
+     if quantization_bits in [4, 8]:
+         nn.quantize(llm_model, bits=quantization_bits, group_size=64,
+                     class_predicate=quant_predicate)
+     # For f32 (32-bit), no quantization needed
+
+     llm_model.load_weights(str(llm_weights_path), strict=True)
+
+     # Tokenizer and processor
+     tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
+     processor = Qwen3VLProcessor(tokenizer=tokenizer)
+
+     return Qwen3VLBundledModel(vision_model=vision_model, llm_model=llm_model), processor
+
+ def apply_chat_template_qwen3_vl(messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = False) -> str:
+     """Apply chat template: serialize messages with content as a list of typed items."""
+     messages_dict = []
+     for msg in messages:
+         content_items = [{"type": "text", "text": msg.content}]
+         messages_dict.append({"role": msg.role, "content": content_items})
+     return json.dumps(messages_dict)
+
+
+ def stream_generate_qwen3_vl(
+     model: Qwen3VLBundledModel,
+     processor: Qwen3VLProcessor,
+     prompt: str,
+     image: Union[str, List[str]] = None,
+     audio: Union[str, List[str]] = None,
+     max_tokens: int = 512,
+     **kwargs,
+
+ ) -> Generator[Any, None, None]:
+     """Stream generation yielding .generate.GenerationResult-compatible chunks."""
+     messages = json.loads(prompt)
+     if image is not None:
+         image_list = image if isinstance(image, list) else [image]
+         pil_images = []
+         for p in image_list:
+             try:
+                 pil_images.append(Image.open(p))
+             except Exception:
+                 continue
+         contents = [{"type": "image", "image": img} for img in pil_images]
+         if messages:
+             if "content" not in messages[-1] or not isinstance(messages[-1]["content"], list):
+                 messages[-1]["content"] = []
+             messages[-1]["content"].extend(contents)
+
+     raw_text, processed_images = processor.messages_to_text(
+         messages, add_generation_prompt=True)
+
+     inputs = processor.text_to_input_ids(
+         raw_text, images=processed_images, return_tensors="mlx")
+
+     input_ids = inputs["input_ids"]
+     pixel_values = inputs.get("pixel_values")
+     image_grid_thw = inputs.get("image_grid_thw")
+
+     inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas = handle_multimodal_embeds(
+         model.vision_model, model.llm_model, input_ids, pixel_values, image_grid_thw
+     )
+
+     prompt_cache = make_prompt_cache(model.llm_model, max_kv_size=4096)
+     tokenizer = processor.tokenizer
+
+     # Rough prompt TPS estimation based on input size
+     prompt_start = time.perf_counter()
+     prompt_tps = input_ids.size / max(1e-6, (time.perf_counter() - prompt_start))
+
+     gen_count = 0
+     tic = time.perf_counter()
+
+     for token, logprobs in nexa_generate_step(
+         model=model.llm_model,
+         prompt=None,
+         input_embeddings=inputs_embeds,
+         max_tokens=max_tokens,
+         max_kv_size=4096,
+         prompt_cache=prompt_cache,
+         visual_pos_masks=visual_pos_masks,
+         deepstack_visual_embeds=deepstack_visual_embeds,
+         cos=cos,
+         sin=sin,
+         rope_deltas=rope_deltas,
+     ):
+         if token == tokenizer.eos_token_id:
+             break
+
+         text_piece = tokenizer.decode([token])
+         gen_count += 1
+
+         yield GenerationResult(
+             text=text_piece,
+             token=token,
+             logprobs=logprobs,
+             prompt_tokens=int(input_ids.size),
+             generation_tokens=gen_count,
+             prompt_tps=float(prompt_tps),
+             generation_tps=float(
+                 gen_count / max(1e-6, (time.perf_counter() - tic))),
+             peak_memory=float(mx.get_peak_memory() / 1e9),
+         )
+
+ def quant_predicate(path: str, mod: nn.Module) -> bool:
+     """Quantization predicate to exclude certain layers from quantization."""
+     if path.endswith("lm_head") or "norm" in path.lower() or "embed" in path.lower():
+         return False
+     return isinstance(mod, (nn.Linear, nn.Embedding))
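
For orientation, here is a minimal usage sketch of the helpers added above. It assumes the new module is importable (the diff does not show its file path), that the model directory contains the expected safetensors and tokenizer files, and that the prompt is the JSON string produced by apply_chat_template_qwen3_vl; the paths and message below are placeholders.

    # Hypothetical usage sketch; model directory and image path are placeholders.
    from ml import ChatMessage

    bundle, processor = load_qwen3_vl("/path/to/qwen3vl/modelfiles")
    prompt = apply_chat_template_qwen3_vl(
        [ChatMessage(role="user", content="Describe this image.")]
    )

    # stream_generate_qwen3_vl yields GenerationResult chunks until EOS or max_tokens.
    for chunk in stream_generate_qwen3_vl(
        bundle,
        processor,
        prompt,
        image="example.jpg",  # placeholder path
        max_tokens=128,
    ):
        print(chunk.text, end="", flush=True)
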
@@ -25,6 +25,8 @@ from profiling import ProfilingMixin, ProfilingData, StopReason

  # Import from the actual mlx_vlm structure
  from .generate import generate, stream_generate, load
+ from .generate_qwen3_vl import apply_chat_template_qwen3_vl, stream_generate_qwen3_vl, load_qwen3_vl
+
  from .modeling.prompt_utils import apply_chat_template

  # --------------------------------------------------------------------------------------
@@ -54,6 +56,7 @@ class VLM(ProfilingMixin):

      def __init__(
          self,
+         model_name: Optional[str],
          model_path: Path,
          mmproj_path: Path,
          context_length: int,
@@ -67,11 +70,13 @@ class VLM(ProfilingMixin):
              model_path = os.path.dirname(model_path)

          self.model_path = model_path
+         self.model_name = model_name
          self.mmproj_path = mmproj_path
          self.context_length = context_length
          self.device = device

-         self.model, self.processor = load(str(model_path))
+         load_impl = load_qwen3_vl if model_name == "qwen3vl" else load
+         self.model, self.processor = load_impl(str(model_path))

          # Init deafutl sampler config with defualt.
          self.sampler_config = SamplerConfig()
@@ -228,9 +233,10 @@ class VLM(ProfilingMixin):
          text = ""
          last_result = None
          first_token = True
+         stream_generate_impl = stream_generate_qwen3_vl if self.model_name == "qwen3vl" else stream_generate

          try:
-             for result in stream_generate(
+             for result in stream_generate_impl(
                  self.model,
                  self.processor,
                  prompt,
@@ -350,6 +356,9 @@ class VLM(ProfilingMixin):

      def apply_chat_template_with_media(self, messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = True) -> str:
          """Apply chat template to messages with proper image/audio token insertion and optional tools support."""
+         if self.model_name == "qwen3vl":
+             return apply_chat_template_qwen3_vl(messages, num_images=num_images, num_audios=num_audios, tools=tools, enable_thinking=enable_thinking)
+
          # Convert ChatMessage objects to dictionaries for the processor
          messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]

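Taken together, these hunks make the VLM wrapper dispatch on the new model_name argument: "qwen3vl" routes loading, chat templating, and streaming through the Qwen3-VL helpers, while any other value keeps the original mlx_vlm path. A minimal construction sketch, with a placeholder model directory:

    # Sketch only: model_path is a placeholder local directory with the Qwen3-VL safetensors.
    vlm = VLM(
        model_name="qwen3vl",   # selects load_qwen3_vl / stream_generate_qwen3_vl
        model_path="/path/to/qwen3vl/modelfiles",
        mmproj_path=None,
        context_length=2048,
        device=None,
    )
    prompt = vlm.apply_chat_template_with_media(
        [ChatMessage(role="user", content="What is in this picture?")],
        num_images=1,
    )
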
@@ -3,6 +3,7 @@ from ml import GenerationConfig, SamplerConfig, ChatMessage
  import re
  import os
  import codecs
+ import argparse

  def parse_media_from_input(user_input):
      """Parse quoted media files from user input and return prompt and media paths"""
@@ -39,12 +40,177 @@ def parse_media_from_input(user_input):

      return prompt, image_paths if image_paths else None, audio_paths if audio_paths else None

- def test_vlm_generate_stream(model_path):
+ def parse_arguments():
+     """Parse command line arguments for the VLM main function."""
+     parser = argparse.ArgumentParser(
+         description="Interactive VLM (Vision-Language Model) conversation interface."
+     )
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="mlx-community/gemma-3-4b-it-8bit",
+         help="The path to the local model directory or Hugging Face repo."
+     )
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         default="",
+         help="Specific model name/type (e.g., 'qwen3vl', 'gemma3'). If empty, auto-detect from model_path."
+     )
+     parser.add_argument(
+         "--context_length",
+         type=int,
+         default=2048,
+         help="Context length for the model (default: 2048)."
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.7,
+         help="Sampling temperature (default: 0.7)."
+     )
+     parser.add_argument(
+         "--top_p",
+         type=float,
+         default=0.9,
+         help="Top-p sampling parameter (default: 0.9)."
+     )
+     parser.add_argument(
+         "--max_tokens",
+         type=int,
+         default=512,
+         help="Maximum tokens to generate (default: 512)."
+     )
+     return parser.parse_args()
+
+ def main():
+     """Main function for interactive VLM conversation."""
+     args = parse_arguments()
+
+     # Auto-detect model name if not provided
+     model_name = args.model_name
+     if not model_name:
+         if "qwen" in args.model_path.lower():
+             model_name = "qwen3vl"
+         elif "gemma" in args.model_path.lower():
+             model_name = "gemma3"
+         else:
+             model_name = ""
+
+     # Load the VLM instance
+     vlm = VLM(
+         model_name=model_name,
+         model_path=args.model_path,
+         mmproj_path=None, # Not needed for this model
+         context_length=args.context_length,
+         device=None
+     )
+
+     # Configure sampler
+     sampler_config = SamplerConfig(
+         temperature=args.temperature,
+         top_p=args.top_p
+     )
+     vlm.set_sampler(sampler_config)
+
+     # Chat history using ChatMessage objects
+     chat = []
+
+     print("VLM Multi-round conversation started. Type 'quit' or 'exit' to end.")
+     print("Include images/audios in quotes, e.g.: 'describe \"image1.jpg\" \"image2.png\"'")
+     print("You can also use single quotes: 'describe '/path/to/image.jpg''")
+     print("=" * 50)
+
+     def on_token(text_chunk):
+         """Token callback for streaming"""
+         print(text_chunk, end="", flush=True)
+         return True
+
+     while True:
+         # Get user input
+         user_input = input("\nUser: ").strip()
+
+         # Check for exit commands
+         if user_input.lower() in ["quit", "exit", "q"]:
+             print("Goodbye!")
+             break
+
+         if not user_input:
+             continue
+
+         # Parse media files and prompt from user input
+         prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)
+
+         # If no text prompt after parsing, use the original input
+         if not prompt_text.strip():
+             prompt_text = user_input
+             image_paths = None
+             audio_paths = None
+
+         # Add user message to chat history using ChatMessage
+         chat.append(ChatMessage(role="user", content=prompt_text))
+
+         # Calculate number of images and audios for chat template
+         num_images = len(image_paths) if image_paths else 0
+         num_audios = len(audio_paths) if audio_paths else 0
+
+         # Apply chat template with image/audio token insertion
+         try:
+             formatted_prompt = vlm.apply_chat_template_with_media(chat, num_images=num_images, num_audios=num_audios)
+         except (NotImplementedError, AttributeError):
+             # Fallback to manual formatting if chat template is not implemented
+             formatted_prompt = ""
+             for msg in chat:
+                 formatted_prompt += f"{msg.role}: {msg.content}\n"
+             formatted_prompt += "Assistant: "
+
+         # Generation config with media paths
+         generation_config = GenerationConfig(
+             max_tokens=args.max_tokens,
+             sampler_config=sampler_config,
+             image_paths=image_paths,
+             audio_paths=audio_paths
+         )
+
+         # Generate response
+         print("Assistant: ", end="", flush=True)
+
+         try:
+             # Use streaming generation with callback
+             response_text = ""
+
+             def token_callback(text_chunk):
+                 nonlocal response_text
+                 print(text_chunk, end="", flush=True)
+                 response_text += text_chunk
+                 return True
+
+             # Use generate_stream method for streaming generation
+             response = vlm.generate_stream(
+                 prompt=formatted_prompt,
+                 config=generation_config,
+                 on_token=token_callback
+             )
+
+             print() # New line after streaming
+
+             # Add assistant response to chat history using ChatMessage
+             chat.append(ChatMessage(role="assistant", content=response_text))
+
+         except Exception as e:
+             print(f"Error generating response: {e}")
+             print()
+
+     # Clean up
+     vlm.destroy()
+
+ def test_vlm_generate_stream(model_path, model_name):
      # Specify the checkpoint
      context_length = 2048

      # Load the corresponding model and VLM instance
      vlm = VLM(
+         model_name=model_name,
          model_path=model_path,
          mmproj_path=None, # Not needed for this model
          context_length=context_length,
@@ -88,9 +254,6 @@ def test_vlm_generate_stream(model_path):
          # Parse media files and prompt from user input
          prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)

-         print(f"image_paths: {image_paths}")
-         print(f"audio_paths: {audio_paths}")
-
          # If no text prompt after parsing, use the original input
          if not prompt_text.strip():
              prompt_text = user_input
@@ -150,8 +313,4 @@ def test_vlm_generate_stream(model_path):
      vlm.destroy()

  if __name__ == "__main__":
-     import argparse
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model_path", type=str, default="mlx-community/gemma-3-4b-it-8bit")
-     args = parser.parse_args()
-     test_vlm_generate_stream(args.model_path)
+     main()
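
With the entry point rewritten around parse_arguments() and main(), an interactive session can be started from the command line instead of editing the script. The script name below is a placeholder, since the diff does not show the file path; the flags are the ones defined in parse_arguments():

    python vlm_cli.py --model_path /path/to/qwen3vl/modelfiles --model_name qwen3vl --temperature 0.7 --top_p 0.9 --max_tokens 512

If --model_name is omitted, main() falls back to substring matching on --model_path ("qwen" selects qwen3vl, "gemma" selects gemma3).
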
@@ -0,0 +1,117 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+ import mlx.core as mx
+ from mlx.utils import tree_map
+
+ from .cache import QuantizedKVCache
+
+
+ @dataclass
+ class BaseModelArgs:
+     @classmethod
+     def from_dict(cls, params):
+         return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+ def create_causal_mask(
+     N: int,
+     offset: int = 0,
+     window_size: Optional[int] = None,
+     lengths: Optional[mx.array] = None,
+ ):
+     rinds = mx.arange(offset + N)
+     linds = mx.arange(offset, offset + N) if offset else rinds
+     linds = linds[:, None]
+     rinds = rinds[None]
+     mask = linds >= rinds
+     if window_size is not None:
+         mask = mask & (linds <= rinds + window_size)
+     if lengths is not None:
+         lengths = lengths[:, None, None, None]
+         mask = mask & (rinds < lengths)
+     return mask
+
+
+ def create_attention_mask(h: mx.array, cache: Optional[Any] = None, return_array: bool = False):
+     T = h.shape[1]
+     if T > 1:
+         offset = 0
+         window_size = None
+         if cache is not None and cache[0] is not None:
+             c = cache[0]
+             offset = c.offset
+             if hasattr(c, "max_size"):
+                 window_size = c.max_size
+                 offset = min(window_size, offset)
+                 return_array = return_array or offset + T > window_size
+         if return_array:
+             return create_causal_mask(T, offset, window_size=window_size)
+         else:
+             return "causal"
+     else:
+         mask = None
+     return mask
+
+
+ def quantized_scaled_dot_product_attention(
+     queries: mx.array,
+     q_keys: tuple[mx.array, mx.array, mx.array],
+     q_values: tuple[mx.array, mx.array, mx.array],
+     scale: float,
+     mask: Optional[mx.array],
+     group_size: int = 64,
+     bits: int = 8,
+ ) -> mx.array:
+     B, n_q_heads, L, D = queries.shape
+     n_kv_heads = q_keys[0].shape[-3]
+     n_repeats = n_q_heads // n_kv_heads
+
+     queries *= scale
+
+     if n_repeats > 1:
+         queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
+         q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
+         q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
+
+     scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
+     if mask is not None:
+         if isinstance(mask, str):
+             qL, kL = scores.shape[-2:]
+             q_indices = mx.arange(kL - qL, kL)
+             k_indices = mx.arange(kL)
+             mask = q_indices[:, None] >= k_indices[None]
+         if mask.dtype == mx.bool_:
+             scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
+         else:
+             scores += mask
+     scores = mx.softmax(scores, axis=-1, precise=True)
+     out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
+
+     if n_repeats > 1:
+         out = mx.reshape(out, (B, n_q_heads, L, D))
+
+     return out
+
+
+ def scaled_dot_product_attention(
+     queries,
+     keys,
+     values,
+     cache,
+     scale: float,
+     mask: Optional[mx.array],
+ ) -> mx.array:
+     if isinstance(cache, QuantizedKVCache):
+         return quantized_scaled_dot_product_attention(
+             queries,
+             keys,
+             values,
+             scale=scale,
+             mask=mask,
+             group_size=cache.group_size,
+             bits=cache.bits,
+         )
+     else:
+         return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
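
These helpers closely mirror the mlx-lm-style attention utilities: create_causal_mask builds a boolean mask over offset cached positions plus N new query positions, and create_attention_mask returns either that array or the string "causal" for the fast kernel, with the quantized path used when the cache is a QuantizedKVCache. A quick illustrative check of the mask semantics, assuming the module is importable (its file path is not shown in the diff):

    import mlx.core as mx

    # Three new query tokens attending over two cached positions plus themselves.
    mask = create_causal_mask(N=3, offset=2)
    print(mask.shape)              # (3, 5)
    print(mask.astype(mx.int32))
    # [[1 1 1 0 0]
    #  [1 1 1 1 0]
    #  [1 1 1 1 1]]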