fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/generate.py
ADDED
|
@@ -0,0 +1,1457 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import codecs
|
|
3
|
+
import contextlib
|
|
4
|
+
import functools
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import mlx.core as mx
|
|
11
|
+
import mlx.nn as nn
|
|
12
|
+
from mlx.utils import tree_reduce
|
|
13
|
+
from mlx_lm.generate import maybe_quantize_kv_cache
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
from transformers import PreTrainedTokenizer
|
|
16
|
+
|
|
17
|
+
from .models import cache
|
|
18
|
+
from .prompt_utils import apply_chat_template
|
|
19
|
+
from .sample_utils import top_p_sampling
|
|
20
|
+
from .utils import (
|
|
21
|
+
StoppingCriteria,
|
|
22
|
+
apply_repetition_penalty,
|
|
23
|
+
group_images_by_shape,
|
|
24
|
+
load,
|
|
25
|
+
prepare_inputs,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
|
|
29
|
+
DEFAULT_IMAGE = None
|
|
30
|
+
DEFAULT_AUDIO = None
|
|
31
|
+
DEFAULT_PROMPT = "What are these?"
|
|
32
|
+
DEFAULT_MAX_TOKENS = 256
|
|
33
|
+
DEFAULT_TEMPERATURE = 0.5
|
|
34
|
+
DEFAULT_TOP_P = 1.0
|
|
35
|
+
DEFAULT_SEED = 0
|
|
36
|
+
DEFAULT_QUANTIZED_KV_START = 5000
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def parse_arguments():
|
|
40
|
+
parser = argparse.ArgumentParser(
|
|
41
|
+
description="Generate text from an image using a model."
|
|
42
|
+
)
|
|
43
|
+
parser.add_argument(
|
|
44
|
+
"--model",
|
|
45
|
+
type=str,
|
|
46
|
+
default=DEFAULT_MODEL_PATH,
|
|
47
|
+
help="The path to the local model directory or Hugging Face repo.",
|
|
48
|
+
)
|
|
49
|
+
parser.add_argument(
|
|
50
|
+
"--adapter-path",
|
|
51
|
+
type=str,
|
|
52
|
+
default=None,
|
|
53
|
+
help="The path to the adapter weights.",
|
|
54
|
+
)
|
|
55
|
+
parser.add_argument(
|
|
56
|
+
"--image",
|
|
57
|
+
type=str,
|
|
58
|
+
nargs="+",
|
|
59
|
+
default=DEFAULT_IMAGE,
|
|
60
|
+
help="URL or path of the image to process.",
|
|
61
|
+
)
|
|
62
|
+
parser.add_argument(
|
|
63
|
+
"--audio",
|
|
64
|
+
type=str,
|
|
65
|
+
nargs="+",
|
|
66
|
+
default=DEFAULT_AUDIO,
|
|
67
|
+
help="URL or path of the audio to process.",
|
|
68
|
+
)
|
|
69
|
+
parser.add_argument(
|
|
70
|
+
"--resize-shape",
|
|
71
|
+
type=int,
|
|
72
|
+
nargs="+",
|
|
73
|
+
default=None,
|
|
74
|
+
help="Resize shape for the image.",
|
|
75
|
+
)
|
|
76
|
+
parser.add_argument(
|
|
77
|
+
"--prompt",
|
|
78
|
+
type=str,
|
|
79
|
+
nargs="+",
|
|
80
|
+
default=DEFAULT_PROMPT,
|
|
81
|
+
help="Message to be processed by the model.",
|
|
82
|
+
)
|
|
83
|
+
parser.add_argument(
|
|
84
|
+
"--system",
|
|
85
|
+
type=str,
|
|
86
|
+
default=None,
|
|
87
|
+
help="System message for the model.",
|
|
88
|
+
)
|
|
89
|
+
parser.add_argument(
|
|
90
|
+
"--max-tokens",
|
|
91
|
+
type=int,
|
|
92
|
+
default=DEFAULT_MAX_TOKENS,
|
|
93
|
+
help="Maximum number of tokens to generate.",
|
|
94
|
+
)
|
|
95
|
+
parser.add_argument(
|
|
96
|
+
"--temperature",
|
|
97
|
+
type=float,
|
|
98
|
+
default=DEFAULT_TEMPERATURE,
|
|
99
|
+
help="Temperature for sampling.",
|
|
100
|
+
)
|
|
101
|
+
parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.")
|
|
102
|
+
parser.add_argument("--verbose", action="store_false", help="Detailed output.")
|
|
103
|
+
parser.add_argument(
|
|
104
|
+
"--eos-tokens",
|
|
105
|
+
type=str,
|
|
106
|
+
nargs="+",
|
|
107
|
+
default=None,
|
|
108
|
+
help="EOS tokens to add to the tokenizer.",
|
|
109
|
+
)
|
|
110
|
+
parser.add_argument(
|
|
111
|
+
"--max-kv-size",
|
|
112
|
+
type=int,
|
|
113
|
+
default=None,
|
|
114
|
+
help="Maximum KV size for the prompt cache.",
|
|
115
|
+
)
|
|
116
|
+
parser.add_argument(
|
|
117
|
+
"--kv-bits",
|
|
118
|
+
type=int,
|
|
119
|
+
default=None,
|
|
120
|
+
help="Number of bits to quantize the KV cache to.",
|
|
121
|
+
)
|
|
122
|
+
parser.add_argument(
|
|
123
|
+
"--kv-group-size",
|
|
124
|
+
type=int,
|
|
125
|
+
default=64,
|
|
126
|
+
help="Group size for the KV cache.",
|
|
127
|
+
)
|
|
128
|
+
parser.add_argument(
|
|
129
|
+
"--quantized-kv-start",
|
|
130
|
+
type=int,
|
|
131
|
+
default=DEFAULT_QUANTIZED_KV_START,
|
|
132
|
+
help="Start index for the quantized KV cache.",
|
|
133
|
+
)
|
|
134
|
+
parser.add_argument(
|
|
135
|
+
"--skip-special-tokens",
|
|
136
|
+
action="store_true",
|
|
137
|
+
help="Skip special tokens in the detokenizer.",
|
|
138
|
+
)
|
|
139
|
+
parser.add_argument(
|
|
140
|
+
"--force-download",
|
|
141
|
+
action="store_true",
|
|
142
|
+
help="Force download the model from Hugging Face.",
|
|
143
|
+
)
|
|
144
|
+
parser.add_argument(
|
|
145
|
+
"--revision",
|
|
146
|
+
type=str,
|
|
147
|
+
default="main",
|
|
148
|
+
help="The specific model version to use (branch, tag, commit).",
|
|
149
|
+
)
|
|
150
|
+
parser.add_argument(
|
|
151
|
+
"--trust-remote-code",
|
|
152
|
+
action="store_true",
|
|
153
|
+
help="Trust remote code when loading the model.",
|
|
154
|
+
)
|
|
155
|
+
parser.add_argument(
|
|
156
|
+
"--processor-kwargs",
|
|
157
|
+
type=json.loads,
|
|
158
|
+
default={},
|
|
159
|
+
help="Extra processor kwargs as JSON. "
|
|
160
|
+
'Example: --processor-kwargs \'{"cropping": false, "max_patches": 3}\'',
|
|
161
|
+
)
|
|
162
|
+
parser.add_argument(
|
|
163
|
+
"--prefill-step-size",
|
|
164
|
+
type=int,
|
|
165
|
+
default=None,
|
|
166
|
+
help="Number of tokens to process per prefill step. "
|
|
167
|
+
"Lower values reduce peak memory usage but may be slower. "
|
|
168
|
+
"Try 512 or 256 if you hit GPU memory errors during prefill.",
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
return parser.parse_args()
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# A stream on the default device just for generation
|
|
175
|
+
generation_stream = mx.new_stream(mx.default_device())
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@contextlib.contextmanager
|
|
179
|
+
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
|
|
180
|
+
"""
|
|
181
|
+
A context manager to temporarily change the wired limit.
|
|
182
|
+
|
|
183
|
+
Note, the wired limit should not be changed during an async eval. If an
|
|
184
|
+
async eval could be running pass in the streams to synchronize with prior
|
|
185
|
+
to exiting the context manager.
|
|
186
|
+
"""
|
|
187
|
+
if not mx.metal.is_available():
|
|
188
|
+
yield
|
|
189
|
+
return
|
|
190
|
+
|
|
191
|
+
model_bytes = tree_reduce(
|
|
192
|
+
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
|
193
|
+
)
|
|
194
|
+
max_rec_size = mx.device_info()["max_recommended_working_set_size"]
|
|
195
|
+
if model_bytes > 0.9 * max_rec_size:
|
|
196
|
+
model_mb = model_bytes // 2**20
|
|
197
|
+
max_rec_mb = max_rec_size // 2**20
|
|
198
|
+
print(
|
|
199
|
+
f"[WARNING] Generating with a model that requires {model_mb} MB "
|
|
200
|
+
f"which is close to the maximum recommended size of {max_rec_mb} "
|
|
201
|
+
"MB. This can be slow. See the documentation for possible work-arounds: "
|
|
202
|
+
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
|
|
203
|
+
)
|
|
204
|
+
old_limit = mx.set_wired_limit(max_rec_size)
|
|
205
|
+
try:
|
|
206
|
+
yield
|
|
207
|
+
finally:
|
|
208
|
+
if streams is not None:
|
|
209
|
+
for s in streams:
|
|
210
|
+
mx.synchronize(s)
|
|
211
|
+
else:
|
|
212
|
+
mx.synchronize()
|
|
213
|
+
mx.set_wired_limit(old_limit)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@dataclass
|
|
217
|
+
class GenerationResult:
|
|
218
|
+
text: str = ""
|
|
219
|
+
token: Optional[int] = None
|
|
220
|
+
logprobs: Optional[List[float]] = None
|
|
221
|
+
prompt_tokens: int = 0
|
|
222
|
+
generation_tokens: int = 0
|
|
223
|
+
total_tokens: int = 0
|
|
224
|
+
prompt_tps: float = 0.0
|
|
225
|
+
generation_tps: float = 0.0
|
|
226
|
+
peak_memory: float = 0.0
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def generate_step(
|
|
230
|
+
input_ids: mx.array,
|
|
231
|
+
model: nn.Module,
|
|
232
|
+
pixel_values,
|
|
233
|
+
mask,
|
|
234
|
+
*,
|
|
235
|
+
max_tokens: int = 256,
|
|
236
|
+
temperature: float = 0.0,
|
|
237
|
+
repetition_penalty: Optional[float] = None,
|
|
238
|
+
repetition_context_size: Optional[int] = 20,
|
|
239
|
+
top_p: float = 1.0,
|
|
240
|
+
logit_bias: Optional[Dict[int, float]] = None,
|
|
241
|
+
prompt_cache: Optional[List[Any]] = None,
|
|
242
|
+
max_kv_size: Optional[int] = None,
|
|
243
|
+
kv_bits: Optional[int] = None,
|
|
244
|
+
kv_group_size: int = 64,
|
|
245
|
+
quantized_kv_start: int = 0,
|
|
246
|
+
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
|
247
|
+
prefill_step_size: Optional[int] = 2048,
|
|
248
|
+
**kwargs,
|
|
249
|
+
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
|
250
|
+
"""
|
|
251
|
+
A generator producing token ids based on the given prompt from the model.
|
|
252
|
+
|
|
253
|
+
Args:
|
|
254
|
+
input_ids (mx.array): The input prompt token ids.
|
|
255
|
+
model (nn.Module): The model to use for generation.
|
|
256
|
+
pixel_values: The pixel values for vision models (optional).
|
|
257
|
+
mask: The attention mask (optional).
|
|
258
|
+
max_tokens (int): Maximum number of tokens to generate. Default: ``256``.
|
|
259
|
+
temperature (float): The temperature for sampling, if 0 the argmax is used.
|
|
260
|
+
Default: ``0``.
|
|
261
|
+
repetition_penalty (float, optional): The penalty factor for repeating
|
|
262
|
+
tokens.
|
|
263
|
+
repetition_context_size (int, optional): The number of tokens to
|
|
264
|
+
consider for repetition penalty. Default: ``20``.
|
|
265
|
+
top_p (float, optional): Nucleus sampling, higher means model considers
|
|
266
|
+
more less likely words.
|
|
267
|
+
logit_bias (dictionary, optional): Additive logit bias.
|
|
268
|
+
prompt_cache (list, optional): Pre-existing KV cache for the prompt.
|
|
269
|
+
max_kv_size (int, optional): Maximum KV cache size.
|
|
270
|
+
kv_bits (int, optional): Number of bits for KV cache quantization.
|
|
271
|
+
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
|
272
|
+
quantized_kv_start (int): Start index for quantized KV cache. Default: ``0``.
|
|
273
|
+
logits_processors (list, optional): List of logits processor functions.
|
|
274
|
+
prefill_step_size (int): Number of tokens to process per prefill step.
|
|
275
|
+
Chunked prefill processes prompts in smaller chunks to reduce peak
|
|
276
|
+
memory usage. Default: ``2048``.
|
|
277
|
+
|
|
278
|
+
Yields:
|
|
279
|
+
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
|
280
|
+
one token and a vector of log probabilities.
|
|
281
|
+
"""
|
|
282
|
+
|
|
283
|
+
quantize_cache_fn = functools.partial(
|
|
284
|
+
maybe_quantize_kv_cache,
|
|
285
|
+
quantized_kv_start=quantized_kv_start,
|
|
286
|
+
kv_group_size=kv_group_size,
|
|
287
|
+
kv_bits=kv_bits,
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
|
291
|
+
if logit_bias:
|
|
292
|
+
indices = mx.array(list(logit_bias.keys()))
|
|
293
|
+
values = mx.array(list(logit_bias.values()))
|
|
294
|
+
logits[:, indices] += values
|
|
295
|
+
logprobs = logits - mx.logsumexp(logits)
|
|
296
|
+
|
|
297
|
+
if temperature == 0:
|
|
298
|
+
token = mx.argmax(logits, axis=-1)
|
|
299
|
+
else:
|
|
300
|
+
if top_p > 0 and top_p < 1.0:
|
|
301
|
+
token = top_p_sampling(logits, top_p, temperature)
|
|
302
|
+
else:
|
|
303
|
+
token = mx.random.categorical(logits * (1 / temperature))
|
|
304
|
+
|
|
305
|
+
return token, logprobs
|
|
306
|
+
|
|
307
|
+
if repetition_penalty and (
|
|
308
|
+
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
|
309
|
+
):
|
|
310
|
+
raise ValueError(
|
|
311
|
+
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
y = input_ids
|
|
315
|
+
tokens = None # Track tokens for logits processors
|
|
316
|
+
|
|
317
|
+
# Create the KV cache for generation
|
|
318
|
+
if prompt_cache is None:
|
|
319
|
+
prompt_cache = cache.make_prompt_cache(
|
|
320
|
+
model.language_model,
|
|
321
|
+
max_kv_size=max_kv_size,
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
repetition_context = input_ids.reshape(-1).tolist()
|
|
325
|
+
|
|
326
|
+
if repetition_context_size:
|
|
327
|
+
repetition_context = repetition_context[-repetition_context_size:]
|
|
328
|
+
|
|
329
|
+
def _step(y, inputs_embeds=None):
|
|
330
|
+
nonlocal tokens, repetition_context, kwargs
|
|
331
|
+
with mx.stream(generation_stream):
|
|
332
|
+
if "decoder_input_ids" in kwargs:
|
|
333
|
+
outputs = model.language_model(
|
|
334
|
+
cache=prompt_cache,
|
|
335
|
+
**kwargs,
|
|
336
|
+
)
|
|
337
|
+
else:
|
|
338
|
+
outputs = model.language_model(
|
|
339
|
+
y,
|
|
340
|
+
inputs_embeds=inputs_embeds,
|
|
341
|
+
cache=prompt_cache,
|
|
342
|
+
**kwargs,
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
logits = outputs.logits[:, -1, :]
|
|
346
|
+
|
|
347
|
+
# Apply logits processors before repetition penalty
|
|
348
|
+
if logits_processors:
|
|
349
|
+
# Efficiently update tokens by concatenating only the new token
|
|
350
|
+
tokens = mx.concat([tokens, y])
|
|
351
|
+
for processor in logits_processors:
|
|
352
|
+
logits = processor(tokens, logits)
|
|
353
|
+
|
|
354
|
+
if repetition_penalty:
|
|
355
|
+
logits = apply_repetition_penalty(
|
|
356
|
+
logits, repetition_context, repetition_penalty
|
|
357
|
+
)
|
|
358
|
+
y, logprobs = sample(logits)
|
|
359
|
+
repetition_context.append(y.item())
|
|
360
|
+
else:
|
|
361
|
+
y, logprobs = sample(logits)
|
|
362
|
+
|
|
363
|
+
if repetition_context_size:
|
|
364
|
+
if len(repetition_context) > repetition_context_size:
|
|
365
|
+
repetition_context = repetition_context[-repetition_context_size:]
|
|
366
|
+
|
|
367
|
+
quantize_cache_fn(prompt_cache)
|
|
368
|
+
|
|
369
|
+
if outputs.cross_attention_states is not None:
|
|
370
|
+
kwargs = {"cross_attention_states": outputs.cross_attention_states}
|
|
371
|
+
elif outputs.encoder_outputs is not None:
|
|
372
|
+
kwargs = {
|
|
373
|
+
"decoder_input_ids": y[None],
|
|
374
|
+
"encoder_outputs": outputs.encoder_outputs,
|
|
375
|
+
}
|
|
376
|
+
else:
|
|
377
|
+
kwargs = {}
|
|
378
|
+
|
|
379
|
+
return y, logprobs.squeeze(0)
|
|
380
|
+
|
|
381
|
+
with mx.stream(generation_stream):
|
|
382
|
+
|
|
383
|
+
# Get input embeddings (handles both multimodal and text-only)
|
|
384
|
+
embedding_output = model.get_input_embeddings(
|
|
385
|
+
input_ids, pixel_values, mask=mask, **kwargs
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
inputs_embeds = embedding_output.inputs_embeds
|
|
389
|
+
|
|
390
|
+
kwargs.update(
|
|
391
|
+
{
|
|
392
|
+
k: v
|
|
393
|
+
for k, v in embedding_output.to_dict().items()
|
|
394
|
+
if k != "inputs_embeds" and v is not None
|
|
395
|
+
}
|
|
396
|
+
)
|
|
397
|
+
if prefill_step_size is not None and inputs_embeds.shape[1] > prefill_step_size:
|
|
398
|
+
# Chunked prefill with embeddings
|
|
399
|
+
total_tokens = inputs_embeds.shape[1]
|
|
400
|
+
with tqdm(total=total_tokens, desc="Prefill", unit="tok") as pbar:
|
|
401
|
+
while inputs_embeds.shape[1] > 1:
|
|
402
|
+
n_to_process = min(prefill_step_size, inputs_embeds.shape[1] - 1)
|
|
403
|
+
model.language_model(
|
|
404
|
+
inputs=input_ids[:, :n_to_process],
|
|
405
|
+
inputs_embeds=inputs_embeds[:, :n_to_process],
|
|
406
|
+
cache=prompt_cache,
|
|
407
|
+
**kwargs,
|
|
408
|
+
)
|
|
409
|
+
quantize_cache_fn(prompt_cache)
|
|
410
|
+
mx.eval([c.state for c in prompt_cache])
|
|
411
|
+
inputs_embeds = inputs_embeds[:, n_to_process:]
|
|
412
|
+
input_ids = input_ids[:, n_to_process:]
|
|
413
|
+
mx.clear_cache()
|
|
414
|
+
pbar.update(n_to_process)
|
|
415
|
+
|
|
416
|
+
input_ids = input_ids[:, -1:]
|
|
417
|
+
|
|
418
|
+
y, logprobs = _step(input_ids, inputs_embeds=inputs_embeds)
|
|
419
|
+
|
|
420
|
+
mx.async_eval(y)
|
|
421
|
+
|
|
422
|
+
n = 0
|
|
423
|
+
while True:
|
|
424
|
+
if n != max_tokens:
|
|
425
|
+
next_y, next_logprobs = _step(y[None])
|
|
426
|
+
mx.async_eval(next_y)
|
|
427
|
+
if n == 0:
|
|
428
|
+
mx.eval(y)
|
|
429
|
+
if n == max_tokens:
|
|
430
|
+
break
|
|
431
|
+
|
|
432
|
+
yield y.item(), logprobs
|
|
433
|
+
if n % 256 == 0:
|
|
434
|
+
mx.clear_cache()
|
|
435
|
+
y, logprobs = next_y, next_logprobs
|
|
436
|
+
n += 1
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def stream_generate(
|
|
440
|
+
model: nn.Module,
|
|
441
|
+
processor: PreTrainedTokenizer,
|
|
442
|
+
prompt: str,
|
|
443
|
+
image: Union[str, List[str]] = None,
|
|
444
|
+
audio: Union[str, List[str]] = None,
|
|
445
|
+
**kwargs,
|
|
446
|
+
) -> Union[str, Generator[str, None, None]]:
|
|
447
|
+
"""
|
|
448
|
+
A generator producing text based on the given prompt from the model.
|
|
449
|
+
|
|
450
|
+
Args:
|
|
451
|
+
model (nn.Module): The model to use for generation.
|
|
452
|
+
processor (PreTrainedTokenizer): The tokenizer/processor.
|
|
453
|
+
prompt (str): The input prompt text.
|
|
454
|
+
image (Union[str, List[str]], optional): Image path(s) or URL(s).
|
|
455
|
+
audio (Union[str, List[str]], optional): Audio file path(s).
|
|
456
|
+
prefill_step_size (int, optional): Number of tokens to process per prefill
|
|
457
|
+
step. When set, enables chunked prefill which processes long prompts in
|
|
458
|
+
smaller chunks to reduce peak memory usage.
|
|
459
|
+
kwargs: Additional options passed to :func:`generate_step`.
|
|
460
|
+
See :func:`generate_step` for more details.
|
|
461
|
+
|
|
462
|
+
Yields:
|
|
463
|
+
Generator[GenerationResult]: A generator producing GenerationResult objects
|
|
464
|
+
containing the generated text, tokens, and statistics.
|
|
465
|
+
"""
|
|
466
|
+
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
|
467
|
+
|
|
468
|
+
# Skip special tokens
|
|
469
|
+
skip_special_tokens = kwargs.pop("skip_special_tokens", False)
|
|
470
|
+
skip_special_token_ids = (
|
|
471
|
+
set(tokenizer.all_special_ids)
|
|
472
|
+
if skip_special_tokens and hasattr(tokenizer, "all_special_ids")
|
|
473
|
+
else []
|
|
474
|
+
)
|
|
475
|
+
|
|
476
|
+
add_special_tokens = (
|
|
477
|
+
not hasattr(processor, "chat_template")
|
|
478
|
+
if model.config.model_type in ["gemma3", "gemma3n"]
|
|
479
|
+
else True
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
resize_shape = kwargs.pop("resize_shape", None)
|
|
483
|
+
image_token_index = getattr(model.config, "image_token_index", None)
|
|
484
|
+
|
|
485
|
+
if kwargs.get("input_ids", None) is not None:
|
|
486
|
+
input_ids = kwargs.pop("input_ids")
|
|
487
|
+
pixel_values = kwargs.pop("pixel_values", None)
|
|
488
|
+
mask = kwargs.pop("mask", None)
|
|
489
|
+
else:
|
|
490
|
+
inputs = prepare_inputs(
|
|
491
|
+
processor,
|
|
492
|
+
images=image,
|
|
493
|
+
audio=audio,
|
|
494
|
+
prompts=prompt,
|
|
495
|
+
image_token_index=image_token_index,
|
|
496
|
+
resize_shape=resize_shape,
|
|
497
|
+
add_special_tokens=add_special_tokens,
|
|
498
|
+
**kwargs,
|
|
499
|
+
)
|
|
500
|
+
input_ids = inputs.get("input_ids", None)
|
|
501
|
+
pixel_values = inputs.get("pixel_values", None)
|
|
502
|
+
mask = inputs.get("attention_mask", None)
|
|
503
|
+
data_kwargs = {
|
|
504
|
+
k: v
|
|
505
|
+
for k, v in inputs.items()
|
|
506
|
+
if k not in ["input_ids", "pixel_values", "attention_mask"]
|
|
507
|
+
}
|
|
508
|
+
kwargs.update(data_kwargs)
|
|
509
|
+
|
|
510
|
+
with wired_limit(model, [generation_stream]):
|
|
511
|
+
detokenizer = processor.detokenizer
|
|
512
|
+
detokenizer.reset()
|
|
513
|
+
tic = time.perf_counter()
|
|
514
|
+
|
|
515
|
+
# #region agent log
|
|
516
|
+
import json
|
|
517
|
+
log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
|
|
518
|
+
def log_debug(location, message, data, hypothesis_id):
|
|
519
|
+
try:
|
|
520
|
+
with open(log_file, "a") as f:
|
|
521
|
+
f.write(json.dumps({"sessionId": "debug-session", "runId": "generation", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
|
|
522
|
+
except: pass
|
|
523
|
+
|
|
524
|
+
log_debug("generate.py:stream_generate_start", "Tokenizer and model info", {
|
|
525
|
+
"model_type": model.config.model_type if hasattr(model.config, "model_type") else "unknown",
|
|
526
|
+
"tokenizer_class": tokenizer.__class__.__name__,
|
|
527
|
+
"vocab_size": tokenizer.vocab_size if hasattr(tokenizer, "vocab_size") else "unknown",
|
|
528
|
+
"eos_token_id": tokenizer.eos_token_id if hasattr(tokenizer, "eos_token_id") else "unknown",
|
|
529
|
+
"bos_token_id": tokenizer.bos_token_id if hasattr(tokenizer, "bos_token_id") else "unknown",
|
|
530
|
+
"pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, "pad_token_id") else "unknown",
|
|
531
|
+
}, "H2,H3,H4")
|
|
532
|
+
# #endregion
|
|
533
|
+
|
|
534
|
+
try:
|
|
535
|
+
for n, (token, logprobs) in enumerate(
|
|
536
|
+
generate_step(input_ids, model, pixel_values, mask, **kwargs)
|
|
537
|
+
):
|
|
538
|
+
if n == 0:
|
|
539
|
+
prompt_time = time.perf_counter() - tic
|
|
540
|
+
prompt_tps = input_ids.size / prompt_time
|
|
541
|
+
tic = time.perf_counter()
|
|
542
|
+
|
|
543
|
+
# #region agent log
|
|
544
|
+
top5_indices = mx.argsort(logprobs)[-5:].tolist()
|
|
545
|
+
top5_values = mx.sort(logprobs)[-5:].tolist()
|
|
546
|
+
log_debug("generate.py:first_token", "First token generated", {
|
|
547
|
+
"token_id": int(token),
|
|
548
|
+
"token_str": tokenizer.decode([token]) if hasattr(tokenizer, "decode") else "N/A",
|
|
549
|
+
"logprobs_shape": str(logprobs.shape),
|
|
550
|
+
"logprobs_top5_indices": top5_indices,
|
|
551
|
+
"logprobs_top5_values": top5_values,
|
|
552
|
+
}, "H2,H4")
|
|
553
|
+
# #endregion
|
|
554
|
+
|
|
555
|
+
# Stop generation if the token is in the eos_token_ids
|
|
556
|
+
if tokenizer.stopping_criteria(token):
|
|
557
|
+
# #region agent log
|
|
558
|
+
log_debug("generate.py:eos_detected", "EOS token detected", {"token_id": int(token), "iteration": n}, "H4")
|
|
559
|
+
# #endregion
|
|
560
|
+
break
|
|
561
|
+
|
|
562
|
+
# #region agent log
|
|
563
|
+
if n < 5: # Log first 5 tokens
|
|
564
|
+
decoded_token = tokenizer.decode([token]) if hasattr(tokenizer, "decode") else "N/A"
|
|
565
|
+
log_debug("generate.py:token_decode", f"Token {n} decode", {
|
|
566
|
+
"iteration": n,
|
|
567
|
+
"token_id": int(token),
|
|
568
|
+
"decoded_single": decoded_token,
|
|
569
|
+
"detokenizer_segment": detokenizer.text if hasattr(detokenizer, "text") else "N/A",
|
|
570
|
+
}, "H2,H4")
|
|
571
|
+
# #endregion
|
|
572
|
+
|
|
573
|
+
detokenizer.add_token(
|
|
574
|
+
token, skip_special_token_ids=skip_special_token_ids
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
# Yield the last segment if streaming
|
|
578
|
+
yield GenerationResult(
|
|
579
|
+
text=detokenizer.last_segment,
|
|
580
|
+
token=token,
|
|
581
|
+
logprobs=logprobs,
|
|
582
|
+
prompt_tokens=input_ids.size,
|
|
583
|
+
generation_tokens=n + 1,
|
|
584
|
+
total_tokens=input_ids.size + n + 1,
|
|
585
|
+
prompt_tps=prompt_tps,
|
|
586
|
+
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
|
587
|
+
peak_memory=mx.get_peak_memory() / 1e9,
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
detokenizer.finalize()
|
|
591
|
+
|
|
592
|
+
yield GenerationResult(
|
|
593
|
+
text=detokenizer.last_segment,
|
|
594
|
+
token=token,
|
|
595
|
+
logprobs=logprobs,
|
|
596
|
+
prompt_tokens=input_ids.size,
|
|
597
|
+
generation_tokens=n + 1,
|
|
598
|
+
total_tokens=input_ids.size + n + 1,
|
|
599
|
+
prompt_tps=prompt_tps,
|
|
600
|
+
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
|
601
|
+
peak_memory=mx.get_peak_memory() / 1e9,
|
|
602
|
+
)
|
|
603
|
+
except Exception as e:
|
|
604
|
+
raise
|
|
605
|
+
|
|
606
|
+
# Cleanup after generation
|
|
607
|
+
mx.clear_cache()
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def generate(
|
|
611
|
+
model: nn.Module,
|
|
612
|
+
processor: PreTrainedTokenizer,
|
|
613
|
+
prompt: str,
|
|
614
|
+
image: Union[str, List[str]] = None,
|
|
615
|
+
audio: Union[str, List[str]] = None,
|
|
616
|
+
verbose: bool = False,
|
|
617
|
+
**kwargs,
|
|
618
|
+
) -> GenerationResult:
|
|
619
|
+
"""
|
|
620
|
+
Generate text from the model.
|
|
621
|
+
|
|
622
|
+
Args:
|
|
623
|
+
model (nn.Module): The language model.
|
|
624
|
+
tokenizer (PreTrainedTokenizer): The tokenizer.
|
|
625
|
+
prompt (str): The string prompt.
|
|
626
|
+
temperature (float): The temperature for sampling (default 0).
|
|
627
|
+
max_tokens (int): The maximum number of tokens (default 100).
|
|
628
|
+
verbose (bool): If ``True``, print tokens and timing information
|
|
629
|
+
(default ``False``).
|
|
630
|
+
formatter (Optional[Callable]): A function which takes a token and a
|
|
631
|
+
probability and displays it.
|
|
632
|
+
repetition_penalty (float, optional): The penalty factor for repeating tokens.
|
|
633
|
+
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
|
|
634
|
+
"""
|
|
635
|
+
|
|
636
|
+
if verbose:
|
|
637
|
+
print("=" * 10)
|
|
638
|
+
files = []
|
|
639
|
+
if image is not None:
|
|
640
|
+
files.extend(image)
|
|
641
|
+
if audio is not None:
|
|
642
|
+
files.extend(audio)
|
|
643
|
+
if kwargs.get("video") is not None:
|
|
644
|
+
files.extend(kwargs.get("video"))
|
|
645
|
+
|
|
646
|
+
print(f"Files: {files}", "\n")
|
|
647
|
+
|
|
648
|
+
print("Prompt:", prompt)
|
|
649
|
+
|
|
650
|
+
text = ""
|
|
651
|
+
last_response = None
|
|
652
|
+
|
|
653
|
+
eos_tokens = kwargs.get("eos_tokens", None)
|
|
654
|
+
stopping_criteria = kwargs.get("stopping_criteria", None)
|
|
655
|
+
|
|
656
|
+
# Get the tokenizer
|
|
657
|
+
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
|
658
|
+
|
|
659
|
+
# Add custom EOS tokens to the stopping criteria
|
|
660
|
+
if eos_tokens is not None:
|
|
661
|
+
tokenizer.stopping_criteria.add_eos_token_ids(eos_tokens)
|
|
662
|
+
|
|
663
|
+
# Use custom stopping criteria
|
|
664
|
+
elif stopping_criteria is not None:
|
|
665
|
+
if isinstance(stopping_criteria, StoppingCriteria) or callable(
|
|
666
|
+
stopping_criteria
|
|
667
|
+
):
|
|
668
|
+
tokenizer.stopping_criteria = stopping_criteria
|
|
669
|
+
else:
|
|
670
|
+
raise ValueError(
|
|
671
|
+
"stopping_criteria must be an instance of StoppingCriteria or a callable"
|
|
672
|
+
)
|
|
673
|
+
else:
|
|
674
|
+
tokenizer.stopping_criteria.reset(model.config.eos_token_id)
|
|
675
|
+
|
|
676
|
+
for response in stream_generate(model, processor, prompt, image, audio, **kwargs):
|
|
677
|
+
if verbose:
|
|
678
|
+
print(response.text, end="", flush=True)
|
|
679
|
+
text += response.text
|
|
680
|
+
last_response = response
|
|
681
|
+
|
|
682
|
+
if verbose:
|
|
683
|
+
print("\n" + "=" * 10)
|
|
684
|
+
if len(text) == 0:
|
|
685
|
+
print("No text generated for this prompt")
|
|
686
|
+
return GenerationResult(
|
|
687
|
+
text=text,
|
|
688
|
+
token=None,
|
|
689
|
+
logprobs=None,
|
|
690
|
+
prompt_tokens=0,
|
|
691
|
+
generation_tokens=0,
|
|
692
|
+
total_tokens=0,
|
|
693
|
+
prompt_tps=0.0,
|
|
694
|
+
generation_tps=0.0,
|
|
695
|
+
peak_memory=mx.get_peak_memory() / 1e9,
|
|
696
|
+
)
|
|
697
|
+
print(
|
|
698
|
+
f"Prompt: {last_response.prompt_tokens} tokens, "
|
|
699
|
+
f"{last_response.prompt_tps:.3f} tokens-per-sec"
|
|
700
|
+
)
|
|
701
|
+
print(
|
|
702
|
+
f"Generation: {last_response.generation_tokens} tokens, "
|
|
703
|
+
f"{last_response.generation_tps:.3f} tokens-per-sec"
|
|
704
|
+
)
|
|
705
|
+
print(f"Peak memory: {last_response.peak_memory:.3f} GB")
|
|
706
|
+
|
|
707
|
+
return GenerationResult(
|
|
708
|
+
text=text,
|
|
709
|
+
token=last_response.token,
|
|
710
|
+
logprobs=last_response.logprobs,
|
|
711
|
+
prompt_tokens=last_response.prompt_tokens,
|
|
712
|
+
generation_tokens=last_response.generation_tokens,
|
|
713
|
+
total_tokens=last_response.total_tokens,
|
|
714
|
+
prompt_tps=last_response.prompt_tps,
|
|
715
|
+
generation_tps=last_response.generation_tps,
|
|
716
|
+
peak_memory=last_response.peak_memory,
|
|
717
|
+
)
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
@dataclass
|
|
721
|
+
class BatchGenerationResult:
|
|
722
|
+
"""
|
|
723
|
+
Result of batch generation with optional image size tracking.
|
|
724
|
+
|
|
725
|
+
Attributes:
|
|
726
|
+
texts: Generated text for each sample
|
|
727
|
+
tokens: Last generated token for each sample
|
|
728
|
+
logprobs: Log probabilities for each sample
|
|
729
|
+
prompt_tokens: Number of prompt tokens per sample
|
|
730
|
+
generation_tokens: Number of generated tokens per sample
|
|
731
|
+
total_tokens: Total tokens (prompt + generation) per sample
|
|
732
|
+
prompt_tps: Prompt tokens per second per sample
|
|
733
|
+
generation_tps: Generation tokens per second per sample
|
|
734
|
+
peak_memory: Peak memory usage in GB
|
|
735
|
+
image_sizes: Original (height, width) for each image (for tracking)
|
|
736
|
+
"""
|
|
737
|
+
|
|
738
|
+
texts: List[str]
|
|
739
|
+
tokens: List[Optional[int]]
|
|
740
|
+
logprobs: List[Optional[List[float]]]
|
|
741
|
+
prompt_tokens: List[int]
|
|
742
|
+
generation_tokens: List[int]
|
|
743
|
+
total_tokens: List[int]
|
|
744
|
+
prompt_tps: List[float]
|
|
745
|
+
generation_tps: List[float]
|
|
746
|
+
peak_memory: float = 0.0
|
|
747
|
+
image_sizes: Optional[List[Tuple[int, int]]] = None
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def _left_pad_prompts(prompts, max_length=None):
|
|
751
|
+
if max_length is None:
|
|
752
|
+
max_length = max(len(p) for p in prompts)
|
|
753
|
+
|
|
754
|
+
return mx.array([[0] * (max_length - len(p)) + p for p in prompts])
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
def _make_cache(model, left_padding):
|
|
758
|
+
"""
|
|
759
|
+
Convert a list of regular caches into their corresponding
|
|
760
|
+
batch-aware caches.
|
|
761
|
+
"""
|
|
762
|
+
|
|
763
|
+
def to_batch_cache(c):
|
|
764
|
+
if isinstance(c, cache.KVCache):
|
|
765
|
+
return cache.BatchKVCache(left_padding)
|
|
766
|
+
elif isinstance(c, cache.ArraysCache):
|
|
767
|
+
c.left_padding = mx.array(left_padding)
|
|
768
|
+
return c
|
|
769
|
+
elif isinstance(c, cache.RotatingKVCache):
|
|
770
|
+
if c.keep > 0:
|
|
771
|
+
raise ValueError("RotatingKVCache with keep tokens is not supported.")
|
|
772
|
+
return cache.BatchRotatingKVCache(c.max_size, left_padding)
|
|
773
|
+
elif isinstance(c, cache.CacheList):
|
|
774
|
+
return cache.BatchCacheList(*(to_batch_cache(sub_c) for sub_c in c.caches))
|
|
775
|
+
else:
|
|
776
|
+
raise ValueError(f"{type(c)} does not yet support batching")
|
|
777
|
+
|
|
778
|
+
if hasattr(model, "make_cache"):
|
|
779
|
+
model_cache = model.make_cache()
|
|
780
|
+
return [to_batch_cache(c) for c in model_cache]
|
|
781
|
+
else:
|
|
782
|
+
return [cache.BatchKVCache(left_padding) for _ in model.layers]
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
@dataclass
|
|
786
|
+
class BatchStats:
|
|
787
|
+
"""
|
|
788
|
+
An data object to hold generation stats.
|
|
789
|
+
|
|
790
|
+
Args:
|
|
791
|
+
prompt_tokens (int): The number of prompt tokens processed.
|
|
792
|
+
prompt_tps (float): The prompt processing tokens-per-second.
|
|
793
|
+
prompt_time (float): The time in seconds spent in prompt processing.
|
|
794
|
+
generation_tokens (int): The number of generated tokens.
|
|
795
|
+
generation_tps (float): The tokens-per-second for generation.
|
|
796
|
+
generation_time (float): The time in seconds spent in generation .
|
|
797
|
+
peak_memory (float): The peak memory used so far in GB.
|
|
798
|
+
"""
|
|
799
|
+
|
|
800
|
+
prompt_tokens: int = 0
|
|
801
|
+
prompt_tps: float = 0
|
|
802
|
+
prompt_time: float = 0
|
|
803
|
+
generation_tokens: int = 0
|
|
804
|
+
generation_tps: float = 0
|
|
805
|
+
generation_time: float = 0
|
|
806
|
+
peak_memory: float = 0
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
@dataclass
|
|
810
|
+
class BatchResponse:
|
|
811
|
+
"""
|
|
812
|
+
An data object to hold a batch generation response.
|
|
813
|
+
|
|
814
|
+
Args:
|
|
815
|
+
texts: (List[str]): The generated text for each prompt.
|
|
816
|
+
stats (BatchStats): Statistics about the generation.
|
|
817
|
+
image_sizes: (Optional[List[Tuple[int, int]]]): Original (height, width)
|
|
818
|
+
for each image. Useful for tracking which images produced which responses
|
|
819
|
+
and for debugging padding/batching behavior.
|
|
820
|
+
"""
|
|
821
|
+
|
|
822
|
+
texts: List[str]
|
|
823
|
+
stats: BatchStats
|
|
824
|
+
image_sizes: Optional[List[Tuple[int, int]]] = None
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
@dataclass
|
|
828
|
+
class Batch:
|
|
829
|
+
uids: List[int]
|
|
830
|
+
y: mx.array
|
|
831
|
+
logprobs: mx.array
|
|
832
|
+
max_tokens: List[int]
|
|
833
|
+
num_tokens: List[int]
|
|
834
|
+
cache: List[Any]
|
|
835
|
+
|
|
836
|
+
def __len__(self):
|
|
837
|
+
return len(self.uids)
|
|
838
|
+
|
|
839
|
+
def filter(self, keep_idx: List[int]):
|
|
840
|
+
self.uids = [self.uids[k] for k in keep_idx]
|
|
841
|
+
self.max_tokens = [self.max_tokens[k] for k in keep_idx]
|
|
842
|
+
self.num_tokens = [self.num_tokens[k] for k in keep_idx]
|
|
843
|
+
keep_idx = mx.array(keep_idx, mx.int32)
|
|
844
|
+
self.y = self.y[keep_idx]
|
|
845
|
+
self.logprobs = self.logprobs[keep_idx]
|
|
846
|
+
for c in self.cache:
|
|
847
|
+
c.filter(keep_idx)
|
|
848
|
+
|
|
849
|
+
def extend(self, other):
|
|
850
|
+
self.uids.extend(other.uids)
|
|
851
|
+
self.y = mx.concatenate([self.y, other.y])
|
|
852
|
+
self.logprobs = mx.concatenate([self.logprobs, other.logprobs])
|
|
853
|
+
self.num_tokens.extend(other.num_tokens)
|
|
854
|
+
self.max_tokens.extend(other.max_tokens)
|
|
855
|
+
for c, o in zip(self.cache, other.cache):
|
|
856
|
+
c.extend(o)
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
class BatchGenerator:
|
|
860
|
+
|
|
861
|
+
@dataclass
|
|
862
|
+
class Response:
|
|
863
|
+
uid: int
|
|
864
|
+
token: int
|
|
865
|
+
logprobs: mx.array
|
|
866
|
+
finish_reason: Optional[str]
|
|
867
|
+
|
|
868
|
+
def __init__(
|
|
869
|
+
self,
|
|
870
|
+
model,
|
|
871
|
+
processor,
|
|
872
|
+
max_tokens: int = 128,
|
|
873
|
+
stop_tokens: Optional[set] = None,
|
|
874
|
+
sampler: Optional[Callable[[mx.array], mx.array]] = None,
|
|
875
|
+
completion_batch_size: int = 32,
|
|
876
|
+
prefill_batch_size: int = 8,
|
|
877
|
+
prefill_step_size: int = 2048,
|
|
878
|
+
prompt_cache=None,
|
|
879
|
+
):
|
|
880
|
+
self.model = model
|
|
881
|
+
self.unprocessed_prompts = []
|
|
882
|
+
self.max_tokens = max_tokens
|
|
883
|
+
self.processor = processor
|
|
884
|
+
self.tokenizer = (
|
|
885
|
+
processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
|
886
|
+
)
|
|
887
|
+
self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
|
888
|
+
self.uid_count = 0
|
|
889
|
+
self.prefill_step_size = prefill_step_size
|
|
890
|
+
self.prefill_batch_size = prefill_batch_size
|
|
891
|
+
self.completion_batch_size = completion_batch_size
|
|
892
|
+
self.prompt_cache = prompt_cache
|
|
893
|
+
self._stats = BatchStats()
|
|
894
|
+
|
|
895
|
+
self.tokenizer.stopping_criteria.add_eos_token_ids(stop_tokens)
|
|
896
|
+
|
|
897
|
+
self.active_batch = None
|
|
898
|
+
|
|
899
|
+
def insert(self, prompts, max_tokens: Union[List[int], int, None] = None):
|
|
900
|
+
uids = []
|
|
901
|
+
|
|
902
|
+
if max_tokens is None or isinstance(max_tokens, int):
|
|
903
|
+
max_tokens = [max_tokens or self.max_tokens] * len(prompts)
|
|
904
|
+
|
|
905
|
+
for p, m in zip(prompts, max_tokens):
|
|
906
|
+
self.unprocessed_prompts.append((self.uid_count, p, m))
|
|
907
|
+
uids.append(self.uid_count)
|
|
908
|
+
self.uid_count += 1
|
|
909
|
+
# Sort in ascending order of length
|
|
910
|
+
self.unprocessed_prompts = sorted(
|
|
911
|
+
self.unprocessed_prompts, key=lambda x: len(x[1])
|
|
912
|
+
)
|
|
913
|
+
return uids
|
|
914
|
+
|
|
915
|
+
def _process_prompts(self, prompts, **kwargs) -> Batch:
|
|
916
|
+
uids, inputs, max_tokens = zip(*prompts)
|
|
917
|
+
lengths = [len(p) for p in inputs]
|
|
918
|
+
max_length = max(lengths)
|
|
919
|
+
|
|
920
|
+
self._stats.prompt_tokens += sum(lengths)
|
|
921
|
+
left_padding = [max_length - l for l in lengths]
|
|
922
|
+
inputs = _left_pad_prompts(inputs, max_length=max_length)
|
|
923
|
+
|
|
924
|
+
prompt_cache = (
|
|
925
|
+
_make_cache(self.model, left_padding)
|
|
926
|
+
if self.prompt_cache is None
|
|
927
|
+
else self.prompt_cache
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
# Slice batch data in kwargs to match current batch size
|
|
931
|
+
batch_size = len(uids)
|
|
932
|
+
for key, value in kwargs.items():
|
|
933
|
+
if isinstance(value, mx.array) and value.ndim > 0:
|
|
934
|
+
kwargs[key] = value[:batch_size]
|
|
935
|
+
|
|
936
|
+
inputs_embeds = kwargs.pop("inputs_embeds", None)
|
|
937
|
+
|
|
938
|
+
if inputs_embeds is not None:
|
|
939
|
+
# Multimodal prefill
|
|
940
|
+
while inputs_embeds.shape[1] > 1:
|
|
941
|
+
n_to_process = min(self.prefill_step_size, inputs_embeds.shape[1] - 1)
|
|
942
|
+
self.model(
|
|
943
|
+
inputs[:, :n_to_process],
|
|
944
|
+
cache=prompt_cache,
|
|
945
|
+
inputs_embeds=inputs_embeds[:, :n_to_process],
|
|
946
|
+
n_to_process=n_to_process,
|
|
947
|
+
**kwargs,
|
|
948
|
+
)
|
|
949
|
+
mx.eval([c.state for c in prompt_cache])
|
|
950
|
+
inputs_embeds = inputs_embeds[:, n_to_process:]
|
|
951
|
+
inputs = inputs[:, n_to_process:]
|
|
952
|
+
mx.clear_cache()
|
|
953
|
+
|
|
954
|
+
kwargs = {"inputs_embeds": inputs_embeds}
|
|
955
|
+
|
|
956
|
+
else:
|
|
957
|
+
# Text-only prefill
|
|
958
|
+
while inputs.shape[1] > 1 and inputs_embeds is None:
|
|
959
|
+
n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1)
|
|
960
|
+
self.model(inputs[:, :n_to_process], cache=prompt_cache)
|
|
961
|
+
mx.eval([c.state for c in prompt_cache])
|
|
962
|
+
inputs = inputs[:, n_to_process:]
|
|
963
|
+
mx.clear_cache()
|
|
964
|
+
|
|
965
|
+
y, logprobs = self._step(inputs, prompt_cache, **kwargs)
|
|
966
|
+
mx.async_eval(y, logprobs)
|
|
967
|
+
mx.clear_cache()
|
|
968
|
+
return Batch(
|
|
969
|
+
list(uids), y, logprobs, list(max_tokens), [0] * len(uids), prompt_cache
|
|
970
|
+
)
|
|
971
|
+
|
|
972
|
+
def _step(self, input_tokens: mx.array, prompt_cache: List[Any], **kwargs):
|
|
973
|
+
output = self.model(input_tokens, cache=prompt_cache, **kwargs)
|
|
974
|
+
logits = output.logits[:, -1, :]
|
|
975
|
+
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
|
976
|
+
sampled = self.sampler(logprobs)
|
|
977
|
+
|
|
978
|
+
# TODO: Add KV cache quantization if specified
|
|
979
|
+
return sampled, logprobs
|
|
980
|
+
|
|
981
|
+
def stats(self):
|
|
982
|
+
self._stats.prompt_tps = self._stats.prompt_tokens / self._stats.prompt_time
|
|
983
|
+
self._stats.generation_tps = (
|
|
984
|
+
self._stats.generation_tokens / self._stats.generation_time
|
|
985
|
+
)
|
|
986
|
+
self._stats.peak_memory = mx.get_peak_memory() / 1e9
|
|
987
|
+
return self._stats
|
|
988
|
+
|
|
989
|
+
def _next(self, **kwargs):
|
|
990
|
+
tic = time.perf_counter()
|
|
991
|
+
|
|
992
|
+
prompt_processing = False
|
|
993
|
+
batch = self.active_batch
|
|
994
|
+
num_active = len(batch) if batch else 0
|
|
995
|
+
num_to_add = self.completion_batch_size - num_active
|
|
996
|
+
while num_to_add >= self.prefill_batch_size:
|
|
997
|
+
prompts = self.unprocessed_prompts[: self.prefill_batch_size]
|
|
998
|
+
# Finish processing the last examples of the last batch
|
|
999
|
+
if len(prompts) == 0 and num_active > 0:
|
|
1000
|
+
break
|
|
1001
|
+
# No more prompts and no more completions, all done
|
|
1002
|
+
elif len(prompts) == 0:
|
|
1003
|
+
self.active_batch = None
|
|
1004
|
+
return []
|
|
1005
|
+
# Process prompts
|
|
1006
|
+
if batch is not None and not prompt_processing:
|
|
1007
|
+
# Finish any active completion tokens
|
|
1008
|
+
mx.eval(batch.y, batch.logprobs)
|
|
1009
|
+
self._stats.generation_time += time.perf_counter() - tic
|
|
1010
|
+
tic = time.perf_counter()
|
|
1011
|
+
|
|
1012
|
+
batch = self._process_prompts(prompts, **kwargs)
|
|
1013
|
+
self.unprocessed_prompts = self.unprocessed_prompts[
|
|
1014
|
+
self.prefill_batch_size :
|
|
1015
|
+
]
|
|
1016
|
+
prompt_processing = True
|
|
1017
|
+
# If there was no active batch, set it
|
|
1018
|
+
if self.active_batch is None:
|
|
1019
|
+
self.active_batch = batch
|
|
1020
|
+
else:
|
|
1021
|
+
self.active_batch.extend(batch)
|
|
1022
|
+
|
|
1023
|
+
num_active = len(self.active_batch)
|
|
1024
|
+
num_to_add -= len(batch)
|
|
1025
|
+
|
|
1026
|
+
batch = self.active_batch
|
|
1027
|
+
y, logprobs = batch.y, batch.logprobs
|
|
1028
|
+
batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
|
|
1029
|
+
mx.async_eval(batch.y, batch.logprobs)
|
|
1030
|
+
|
|
1031
|
+
y = y.tolist()
|
|
1032
|
+
toc = time.perf_counter()
|
|
1033
|
+
if prompt_processing:
|
|
1034
|
+
self._stats.prompt_time += toc - tic
|
|
1035
|
+
else:
|
|
1036
|
+
self._stats.generation_time += toc - tic
|
|
1037
|
+
keep_idx = []
|
|
1038
|
+
end_idx = []
|
|
1039
|
+
responses = []
|
|
1040
|
+
|
|
1041
|
+
for e, (t, uid, num_tok, max_tok) in enumerate(
|
|
1042
|
+
zip(y, batch.uids, batch.num_tokens, batch.max_tokens)
|
|
1043
|
+
):
|
|
1044
|
+
num_tok += 1
|
|
1045
|
+
batch.num_tokens[e] = num_tok
|
|
1046
|
+
if self.tokenizer.stopping_criteria(t):
|
|
1047
|
+
finish_reason = "stop"
|
|
1048
|
+
end_idx.append(e)
|
|
1049
|
+
elif num_tok >= max_tok:
|
|
1050
|
+
finish_reason = "length"
|
|
1051
|
+
end_idx.append(e)
|
|
1052
|
+
else:
|
|
1053
|
+
finish_reason = None
|
|
1054
|
+
keep_idx.append(e)
|
|
1055
|
+
responses.append(self.Response(uid, t, logprobs[e], finish_reason))
|
|
1056
|
+
|
|
1057
|
+
# Remove any finished completions
|
|
1058
|
+
if len(end_idx):
|
|
1059
|
+
if len(keep_idx) > 0:
|
|
1060
|
+
batch.filter(keep_idx)
|
|
1061
|
+
else:
|
|
1062
|
+
self.active_batch = None
|
|
1063
|
+
|
|
1064
|
+
self._stats.generation_tokens += len(responses)
|
|
1065
|
+
|
|
1066
|
+
if len(responses) > 0 and self._stats.generation_tokens % 100 == 0:
|
|
1067
|
+
mx.clear_cache()
|
|
1068
|
+
|
|
1069
|
+
return responses
|
|
1070
|
+
|
|
1071
|
+
def next(self, **kwargs):
|
|
1072
|
+
with mx.stream(generation_stream):
|
|
1073
|
+
return self._next(**kwargs)
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
def batch_generate(
    model,
    processor,
    images: Union[str, List[str]] = None,
    audios: Union[str, List[str]] = None,
    prompts: List[str] = None,
    max_tokens: Union[int, List[int]] = 128,
    verbose: bool = False,
    group_by_shape: bool = True,
    track_image_sizes: bool = True,
    **kwargs,
):
    """
    Generate responses for the given batch of prompts with variable-sized images.

    This function implements the transformers-style approach to batching:
    1. Group images with the same shape for efficient batch processing
    2. Process each group as a batch (no padding waste within groups)
    3. Track original image sizes for proper attention masking
    4. Restore results to original batch order

    Key insight: Instead of padding all images to the same spatial dimensions
    (which wastes computation and may hurt accuracy), we group same-sized
    images together so there's zero padding within each group.

    Args:
        model (nn.Module): The language model.
        processor (PreTrainedTokenizer): The tokenizer/processor.
        images (Union[str, List[str]]): Images (paths, URLs, or PIL images).
        audios (Union[str, List[str]]): Audio files (not yet supported for batching).
        prompts (List[str]): The input prompts.
        max_tokens (Union[int, List[int]]): Maximum number of output tokens. This
            can be per prompt if a list is provided.
        verbose (bool): If ``True``, print tokens and timing information.
            Default: ``False``.
        group_by_shape (bool): If ``True``, group same-shaped images for efficient
            batch processing. Default: ``True``.
        track_image_sizes (bool): If ``True``, track and return original image sizes.
            Default: ``True``.
        kwargs: The remaining options get passed to :obj:`BatchGenerator`.
            See :obj:`BatchGenerator` for more details.

    Returns:
        BatchResponse with generated texts, statistics, and optionally image_sizes.
    """
    from PIL import Image

    from .utils import process_image

    processor.detokenizer.reset()
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

    # Handle single image case
    if isinstance(images, str):
        images = [images]

    # Handle no images case
    if images is None:
        texts, stats = _generate_batch(
            model, processor, prompts, None, max_tokens, verbose, **kwargs
        )
        return BatchResponse(texts, stats)

    # Load and preprocess images
    image_processor = (
        processor.image_processor if hasattr(processor, "image_processor") else None
    )

    processed_images = []
    image_sizes_original = []
    for img in images:
        if isinstance(img, str):
            pil_img = process_image(img, None, image_processor)
        elif isinstance(img, Image.Image):
            pil_img = img
        else:
            pil_img = img
        processed_images.append(pil_img)
        # Track original size
        if hasattr(pil_img, "height"):
            image_sizes_original.append((pil_img.height, pil_img.width))
        else:
            image_sizes_original.append((0, 0))

    # Group images by shape for efficient processing (no padding within groups)
    if group_by_shape and len(processed_images) > 1:
        grouped_images, grouped_indices = group_images_by_shape(processed_images)

        if verbose:
            print(f"[batch_generate] Found {len(grouped_images)} unique image shapes")
    else:
        # Single image or grouping disabled - treat as one group
        shape = (
            (processed_images[0].height, processed_images[0].width)
            if processed_images
            else (0, 0)
        )
        grouped_images = {shape: processed_images}
        grouped_indices = {shape: list(range(len(processed_images)))}

    # Process each shape group
    all_texts = [None] * len(prompts)
    all_image_sizes = [None] * len(prompts)
    total_stats = BatchStats()

    for shape, indices in grouped_indices.items():
        # Get images and prompts for this shape group
        group_images = [processed_images[i] for i in indices]
        group_prompts = [prompts[i] for i in indices]
        group_sizes = [image_sizes_original[i] for i in indices]

        # Handle per-sample max_tokens
        if isinstance(max_tokens, list):
            group_max_tokens = [max_tokens[i] for i in indices]
        else:
            group_max_tokens = max_tokens

        # Process the entire group at once (same shape = no padding needed)
        chunk_texts, chunk_stats = _generate_batch(
            model,
            processor,
            group_prompts,
            group_images,
            group_max_tokens,
            **kwargs,
        )

        # Store results in original order
        for j, orig_idx in enumerate(indices):
            all_texts[orig_idx] = chunk_texts[j]
            all_image_sizes[orig_idx] = group_sizes[j]

        # Accumulate stats
        total_stats.prompt_tokens += chunk_stats.prompt_tokens
        total_stats.prompt_time += chunk_stats.prompt_time
        total_stats.generation_tokens += chunk_stats.generation_tokens
        total_stats.generation_time += chunk_stats.generation_time

        mx.clear_cache()

    # Compute final stats
    if total_stats.prompt_time > 0:
        total_stats.prompt_tps = total_stats.prompt_tokens / total_stats.prompt_time
    if total_stats.generation_time > 0:
        total_stats.generation_tps = (
            total_stats.generation_tokens / total_stats.generation_time
        )
    total_stats.peak_memory = mx.get_peak_memory() / 1e9

    if verbose:
        print(f"[batch_generate] Finished processing {len(prompts)} samples")
        print(
            f"[batch_generate] Prompt: {total_stats.prompt_tokens} tokens, {total_stats.prompt_tps:.3f} tokens-per-sec"
        )
        print(
            f"[batch_generate] Generation: {total_stats.generation_tokens} tokens, "
            f"{total_stats.generation_tps:.3f} tokens-per-sec"
        )
        print(f"[batch_generate] Peak memory: {total_stats.peak_memory:.3f} GB")

    response = BatchResponse(all_texts, total_stats)
    if track_image_sizes:
        response.image_sizes = all_image_sizes
    return response
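
# Illustrative usage sketch for batch_generate (not part of the packaged module);
# the model id, image paths, and the import of `load` are assumptions used only
# for illustration:
#
#   from mlx_vlm import load
#
#   model, processor = load("mlx-community/some-vlm-4bit")   # hypothetical repo id
#   response = batch_generate(
#       model,
#       processor,
#       images=["cat.png", "chart.png", "receipt.png"],  # images may differ in shape
#       prompts=["Describe this image."] * 3,
#       max_tokens=[64, 64, 128],                        # per-prompt limits allowed
#       verbose=True,
#   )
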
def _generate_batch(
    model,
    processor,
    prompts: List[str],
    images: List = None,
    max_tokens: Union[int, List[int]] = 100,
    verbose: bool = False,
    **kwargs,
) -> Tuple[List[str], BatchStats]:

    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
    batch_size = len(prompts)

    num_images_list = [
        1 if i < (len(images) if images is not None else 0) else 0
        for i in range(len(prompts))
    ]
    formatted_prompts = [
        apply_chat_template(
            processor,
            model.config,
            p,
            num_images=num_images_list[i],
        )
        for i, p in enumerate(prompts)
    ]

    add_special_tokens = (
        not hasattr(processor, "chat_template")
        if model.config.model_type in ["gemma3", "gemma3n"]
        else True
    )

    resize_shape = kwargs.pop("resize_shape", None)
    image_token_index = getattr(model.config, "image_token_index", None)

    inputs = prepare_inputs(
        processor,
        images=images,
        audio=None,
        prompts=formatted_prompts,
        image_token_index=image_token_index,
        resize_shape=resize_shape,
        add_special_tokens=add_special_tokens,
        pad_to_uniform_size=False,  # Images are pre-grouped by shape, so they are already uniform in size
    )
    input_ids = inputs.get("input_ids", None)
    pixel_values = inputs.get("pixel_values", None)

    data_kwargs = {
        k: v
        for k, v in inputs.items()
        if k not in ["input_ids", "pixel_values", "attention_mask"]
    }

    # Use batch_size for prefill and completion to ensure consistent processing
    gen = BatchGenerator(
        model.language_model,
        processor,
        prefill_batch_size=batch_size,
        completion_batch_size=batch_size,
        **kwargs,
    )

    with wired_limit(model, [generation_stream]):
        if pixel_values is not None:
            embedding_output = model.get_input_embeddings(
                input_ids, pixel_values, **data_kwargs
            )

            # Normalize embedding output to a kwargs dict expected by BatchGenerator
            if isinstance(embedding_output, dict):
                embed_kwargs = embedding_output
            elif hasattr(embedding_output, "to_dict"):
                # Convert to dict and keep non-None fields
                embed_kwargs = {
                    k: v for k, v in embedding_output.to_dict().items() if v is not None
                }
            else:
                # Assume it's directly an inputs_embeds array
                embed_kwargs = {"inputs_embeds": embedding_output}

            gen_kwargs = {
                "pixel_values": pixel_values,
                **data_kwargs,
                **embed_kwargs,
            }
        else:
            input_ids = mx.squeeze(input_ids, axis=0)
            gen_kwargs = {}

        uids = gen.insert(input_ids.tolist(), max_tokens)
        results = {uid: [] for uid in uids}
        while responses := gen.next(**gen_kwargs):
            for r in responses:
                if r.finish_reason != "stop":
                    results[r.uid].append(r.token)

        texts = [tokenizer.decode(results[uid]) for uid in uids]
        return texts, gen.stats()


def main():
    args = parse_arguments()
    if isinstance(args.image, str):
        args.image = [args.image]

    model, processor = load(
        args.model,
        args.adapter_path,
        revision=args.revision,
        trust_remote_code=args.trust_remote_code,
    )
    config = model.config

    prompt = args.prompt

    num_images = len(args.image) if args.image is not None else 0
    num_audios = (
        1 if args.audio is not None else 0
    )  # TODO: Support multiple audio files
    prompt = apply_chat_template(
        processor, config, prompt, num_images=num_images, num_audios=num_audios
    )

    kwargs = {}

    if args.resize_shape is not None:
        if len(args.resize_shape) not in [1, 2]:
            raise ValueError("Resize shape must be 1 or 2 integers")
        kwargs["resize_shape"] = (
            (args.resize_shape[0],) * 2
            if len(args.resize_shape) == 1
            else tuple(args.resize_shape)
        )

    if args.eos_tokens is not None:
        eos_tokens = []
        for token in args.eos_tokens:
            try:
                decoded_token = codecs.decode(token, "unicode_escape")
                eos_tokens.append(decoded_token)
            except (UnicodeDecodeError, UnicodeError):
                eos_tokens.append(token)
        kwargs["eos_tokens"] = eos_tokens

    if args.skip_special_tokens:
        kwargs["skip_special_tokens"] = args.skip_special_tokens

    # Add processor kwargs from JSON
    if args.processor_kwargs:
        kwargs.update(args.processor_kwargs)

    if args.chat:
        chat = []
        if args.system:
            chat.append({"role": "system", "content": args.system})
        while user := input("User:"):
            chat.append({"role": "user", "content": user})
            prompt = apply_chat_template(processor, config, chat, num_images=num_images)
            response = ""
            print("Assistant:", end="")
            stream_kwargs = {
                "max_tokens": args.max_tokens,
                "temperature": args.temperature,
                **kwargs,
            }
            if args.prefill_step_size is not None:
                stream_kwargs["prefill_step_size"] = args.prefill_step_size

            for chunk in stream_generate(
                model,
                processor,
                prompt,
                args.image,
                args.audio,
                **stream_kwargs,
            ):
                response += chunk.text
                print(chunk.text, end="")

            chat.append({"role": "assistant", "content": response})
            print()

    else:
        gen_kwargs = {
            "image": args.image,
            "audio": args.audio,
            "temperature": args.temperature,
            "max_tokens": args.max_tokens,
            "verbose": args.verbose,
            "max_kv_size": args.max_kv_size,
            "kv_bits": args.kv_bits,
            "kv_group_size": args.kv_group_size,
            "quantized_kv_start": args.quantized_kv_start,
            **kwargs,
        }
        if args.prefill_step_size is not None:
            gen_kwargs["prefill_step_size"] = args.prefill_step_size

        result = generate(
            model,
            processor,
            prompt,
            **gen_kwargs,
        )
        if not args.verbose:
            print(result.text)


if __name__ == "__main__":
    print(
        "Calling `python -m mlx_vlm.generate ...` directly is deprecated."
        " Use `mlx_vlm generate` or `python -m mlx_vlm generate` instead."
    )
    main()
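
# Illustrative CLI sketch for the entry point above (flag spellings are assumptions;
# parse_arguments defines the actual interface):
#
#   python -m mlx_vlm generate \
#       --model <model-id> \
#       --image example.jpg \
#       --prompt "Describe this image." \
#       --max-tokens 128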