fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/server.py
ADDED
|
@@ -0,0 +1,1107 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import asyncio
|
|
3
|
+
import gc
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import traceback
|
|
7
|
+
import uuid
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, List, Literal, Optional, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import mlx.core as mx
|
|
12
|
+
import uvicorn
|
|
13
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
14
|
+
from fastapi.responses import StreamingResponse
|
|
15
|
+
from huggingface_hub import scan_cache_dir
|
|
16
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
17
|
+
from typing_extensions import Required, TypeAlias, TypedDict
|
|
18
|
+
|
|
19
|
+
from .generate import (
|
|
20
|
+
DEFAULT_MAX_TOKENS,
|
|
21
|
+
DEFAULT_MODEL_PATH,
|
|
22
|
+
DEFAULT_SEED,
|
|
23
|
+
DEFAULT_TEMPERATURE,
|
|
24
|
+
DEFAULT_TOP_P,
|
|
25
|
+
generate,
|
|
26
|
+
stream_generate,
|
|
27
|
+
)
|
|
28
|
+
from .prompt_utils import apply_chat_template
|
|
29
|
+
from .utils import load
|
|
30
|
+
from .version import __version__
|
|
31
|
+
|
|
32
|
+
# FastAPI application object; the endpoint decorators in this module attach to it.
app = FastAPI(
    title="MLX-VLM Inference API",
    description="API for using Vision Language Models (VLMs) and Omni Models (Vision, Audio and Video support) with MLX.",
    version=__version__,
)

MAX_IMAGES = 10  # Maximum number of images to process at once

# Loading/unloading utilities

# Module-level cache describing the currently loaded model resources.
# An empty dict means nothing is loaded.
model_cache = {}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class FlexibleBaseModel(BaseModel):
    """Base model that ignores/accepts any unknown OpenAI SDK fields."""

    # extra="allow": fields not declared on a subclass are accepted
    # rather than rejected during validation.
    model_config = ConfigDict(extra="allow")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_model_resources(model_path: str, adapter_path: Optional[str]):
    """Load the model, its processor, and the model's config.

    Args:
        model_path: Model location, forwarded to ``load``.
        adapter_path: Optional adapter location, forwarded to ``load``.

    Returns:
        A ``(model, processor, config)`` tuple where ``config`` is
        ``model.config``.

    Raises:
        HTTPException: Status 500 when anything in the loading path raises.
            The original exception is attached as ``__cause__``.
    """
    try:
        print(f"Loading model from: {model_path}")
        if adapter_path:
            print(f"Loading adapter from: {adapter_path}")
        # Remote code is opt-in: only the value "true" (case-insensitive) in
        # the MLX_TRUST_REMOTE_CODE environment variable enables it.
        trust_remote_code = (
            os.environ.get("MLX_TRUST_REMOTE_CODE", "false").lower() == "true"
        )
        model, processor = load(
            model_path, adapter_path, trust_remote_code=trust_remote_code
        )
        config = model.config
        print("Model and processor loaded successfully.")
        return model, processor, config
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        traceback.print_exc()  # Print detailed traceback for debugging
        # Chain explicitly so the root cause stays attached to the HTTP error.
        raise HTTPException(
            status_code=500, detail=f"Failed to load model: {e}"
        ) from e
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_cached_model(model_path: str, adapter_path: Optional[str] = None):
    """Return ``(model, processor, config)`` for the requested paths.

    The module-level ``model_cache`` is reused when its ``cache_key`` equals
    ``(model_path, adapter_path)``. Otherwise any cached entry is unloaded,
    the resources are loaded, and the cache is replaced with the new entry.
    """
    global model_cache

    cache_key = (model_path, adapter_path)

    if model_cache.get("cache_key") != cache_key:
        # Something else (or nothing) is cached: drop it before loading.
        if model_cache:
            print("New model request, clearing existing cache...")
            unload_model_sync()

        model, processor, config = load_model_resources(model_path, adapter_path)

        model_cache = {
            "cache_key": cache_key,
            "model_path": model_path,
            "adapter_path": adapter_path,
            "model": model,
            "processor": processor,
            "config": config,
        }
        return model, processor, config

    # Cache hit: hand back the already-loaded resources.
    print(f"Using cached model: {model_path}, Adapter: {adapter_path}")
    return model_cache["model"], model_cache["processor"], model_cache["config"]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# Synchronous unload function for internal use
|
|
110
|
+
def unload_model_sync():
    """Drop the cached model resources, if any.

    Returns:
        ``False`` when the cache was already empty; ``True`` after the cache
        was reset, garbage collection was run and ``mx.clear_cache()`` called.
    """
    global model_cache

    if model_cache:
        loaded_path = model_cache.get("model_path")
        loaded_adapter = model_cache.get("adapter_path")
        print(f"Unloading model: {loaded_path}, Adapter: {loaded_adapter}")

        # Rebinding the global releases this module's references.
        model_cache = {}
        gc.collect()
        mx.clear_cache()
        print("Model unloaded and cache cleared.")
        return True

    return False
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# OpenAI API Models
|
|
128
|
+
|
|
129
|
+
# Models for /responses endpoint
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class ResponseInputTextParam(TypedDict, total=False):
|
|
133
|
+
text: Required[str]
|
|
134
|
+
type: Required[
|
|
135
|
+
Literal["input_text", "text"]
|
|
136
|
+
] # The type of the input item. Always `input_text`.
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class ResponseInputImageParam(TypedDict, total=False):
|
|
140
|
+
detail: Literal["high", "low", "auto"] = Field(
|
|
141
|
+
"auto", description="The detail level of the image to be sent to the model."
|
|
142
|
+
)
|
|
143
|
+
"""The detail level of the image to be sent to the model.
|
|
144
|
+
|
|
145
|
+
One of `high`, `low`, or `auto`. Defaults to `auto`.
|
|
146
|
+
"""
|
|
147
|
+
type: Required[
|
|
148
|
+
Literal["input_image"]
|
|
149
|
+
] # The type of the input item. Always `input_image`.
|
|
150
|
+
image_url: Required[str]
|
|
151
|
+
file_id: Optional[str]
|
|
152
|
+
"""The ID of the file to be sent to the model.
|
|
153
|
+
NOTE : wouldn't this help the model if we passed the file_id as well to the vlm models
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class InputAudio(TypedDict, total=False):
|
|
158
|
+
data: Required[str]
|
|
159
|
+
format: Required[str]
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class ResponseInputAudioParam(TypedDict, total=False):
    # Audio content part of an input message.
    # Discriminator for the part; always "input_audio".
    type: Required[Literal["input_audio"]]
    # The audio payload itself (see InputAudio).
    input_audio: Required[InputAudio]
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class ImageUrl(TypedDict, total=False):
|
|
170
|
+
url: Required[str]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class ResponseImageUrlParam(TypedDict, total=False):
    # Image content part whose URL is nested in an object.
    # Discriminator for the part; always "image_url".
    type: Required[Literal["image_url"]]
    # Nested object carrying the URL (see ImageUrl).
    image_url: Required[ImageUrl]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# Any single content part that may appear in an input message.
ResponseInputContentParam: TypeAlias = Union[
    ResponseInputTextParam,
    ResponseInputImageParam,
    ResponseImageUrlParam,
    ResponseInputAudioParam,
]

# Message content expressed as a list of such parts.
ResponseInputMessageContentListParam: TypeAlias = List[ResponseInputContentParam]
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class ResponseOutputText(TypedDict, total=False):
|
|
191
|
+
text: Required[str]
|
|
192
|
+
type: Required[
|
|
193
|
+
Literal["output_text"]
|
|
194
|
+
] # The type of the output item. Always `output_text`
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Output message content expressed as a list of output-text parts.
ResponseOutputMessageContentList: TypeAlias = List[ResponseOutputText]
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class ChatMessage(FlexibleBaseModel):
    # A single message: who sent it and what it contains.

    # Sender role, restricted to the four literals below.
    role: Literal["user", "assistant", "system", "developer"] = Field(
        ...,
        description="Role of the message sender (e.g., 'system', 'user', 'assistant').",
    )
    # A plain string, a list of input content parts, or a list of
    # output-text parts.
    content: Union[
        str,
        ResponseInputMessageContentListParam,
        ResponseOutputMessageContentList,
    ] = Field(..., description="Content of the message.")
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class OpenAIRequest(FlexibleBaseModel):
    """
    OpenAI-compatible request structure.
    Using this structure : https://github.com/openai/openai-python/blob/main/src/openai/resources/responses/responses.py
    """

    # Prompt: a bare string or a list of chat messages.
    input: Union[str, List[ChatMessage]] = Field(
        ..., description="Input text or list of chat messages."
    )
    model: str = Field(..., description="The model to use for generation.")

    # Sampling controls; defaults come from the generate module's constants.
    max_output_tokens: int = Field(
        default=DEFAULT_MAX_TOKENS,
        description="Maximum number of tokens to generate.",
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE, description="Temperature for sampling."
    )
    top_p: float = Field(default=DEFAULT_TOP_P, description="Top-p sampling.")

    stream: bool = Field(
        default=False,
        description="Whether to stream the response chunk by chunk.",
    )
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class OpenAIUsage(BaseModel):
    """Token usage details including input tokens, output tokens, breakdown, and total tokens used."""

    # Three integer counters, none with a default.
    input_tokens: int
    output_tokens: int
    total_tokens: int
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class OpenAIErrorObject(BaseModel):
    """Error object returned when the model fails to generate a Response."""

    # Every attribute is an optional string defaulting to None.
    code: Optional[str] = None
    message: Optional[str] = None
    param: Optional[str] = None
    type: Optional[str] = None
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class OpenAIResponse(BaseModel):
    # Response object for the /responses endpoint.

    # --- identity and lifecycle ---
    id: str = Field(..., description="Unique identifier for this Response")
    object: Literal["response"] = Field(
        ...,
        description="The object type of this resource - always set to response",
    )
    created_at: int = Field(
        ...,
        description="Unix timestamp (in seconds) of when this Response was created",
    )
    status: Literal["completed", "failed", "in_progress", "incomplete"] = Field(
        ..., description="The status of the response generation"
    )
    error: Optional[OpenAIErrorObject] = Field(
        default=None,
        description="An error object returned when the model fails to generate a Response",
    )

    # --- request parameters carried on the response ---
    instructions: Optional[str] = Field(
        default=None,
        description="Inserts a system (or developer) message as the first item in the model's context",
    )
    max_output_tokens: Optional[int] = Field(
        default=None,
        description="An upper bound for the number of tokens that can be generated for a response",
    )
    model: str = Field(..., description="Model ID used to generate the response")

    # --- generated content ---
    output: List[Union[ChatMessage, Any]] = Field(
        ..., description="An array of content items generated by the model"
    )
    output_text: Optional[str] = Field(
        default=None,
        description="SDK-only convenience property containing aggregated text output",
    )

    # --- sampling settings, range-checked via ge/le ---
    temperature: Optional[float] = Field(
        default=None,
        ge=0,
        le=2,
        description="Sampling temperature between 0 and 2",
    )
    top_p: Optional[float] = Field(
        default=None,
        ge=0,
        le=1,
        description="Nucleus sampling probability mass",
    )
    truncation: Union[Literal["auto", "disabled"], str] = Field(
        default="disabled", description="The truncation strategy to use"
    )

    # --- accounting ---
    # we need the model to return stats
    usage: OpenAIUsage = Field(..., description="Token usage details")
    user: Optional[str] = Field(
        default=None,
        description="A unique identifier representing your end-user",
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class BaseStreamEvent(BaseModel):
    # Common parent of the streaming event models; subclasses narrow `type`
    # to a specific literal.
    type: str
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class ContentPartOutputText(BaseModel):
    # One output-text content part.
    type: Literal["output_text"]
    text: str
    # List of strings, empty by default.
    annotations: List[str] = []
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class MessageItem(BaseModel):
    # A message output item as carried by the output_item.* stream events.
    id: str
    type: Literal["message"]
    status: Literal["in_progress", "completed"]
    role: str
    # Output-text parts of the message, empty by default.
    content: List[ContentPartOutputText] = []
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class ResponseCreatedEvent(BaseStreamEvent):
    # Stream event tagged "response.created"; carries a full response object.
    type: Literal["response.created"]
    response: OpenAIResponse
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class ResponseInProgressEvent(BaseStreamEvent):
    # Stream event tagged "response.in_progress"; carries a full response
    # object.
    type: Literal["response.in_progress"]
    response: OpenAIResponse
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class ResponseOutputItemAddedEvent(BaseStreamEvent):
    # Stream event tagged "response.output_item.added".
    type: Literal["response.output_item.added"]
    # Integer index of the item, followed by the item itself.
    output_index: int
    item: MessageItem
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class ResponseContentPartAddedEvent(BaseStreamEvent):
    # Stream event tagged "response.content_part.added".
    type: Literal["response.content_part.added"]
    # Identifiers locating the part: item id, output index, content index.
    item_id: str
    output_index: int
    content_index: int
    # The content part being reported.
    part: ContentPartOutputText
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class ResponseOutputTextDeltaEvent(BaseStreamEvent):
    # Stream event tagged "response.output_text.delta".
    type: Literal["response.output_text.delta"]
    # Identifiers locating the part: item id, output index, content index.
    item_id: str
    output_index: int
    content_index: int
    # The text fragment carried by this event.
    delta: str
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class ResponseOutputTextDoneEvent(BaseStreamEvent):
    # Stream event tagged "response.output_text.done".
    type: Literal["response.output_text.done"]
    # Identifiers locating the part: item id, output index, content index.
    item_id: str
    output_index: int
    content_index: int
    # The text carried by this event.
    text: str
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class ResponseContentPartDoneEvent(BaseStreamEvent):
    # Stream event tagged "response.content_part.done".
    type: Literal["response.content_part.done"]
    # Identifiers locating the part: item id, output index, content index.
    item_id: str
    output_index: int
    content_index: int
    # The content part being reported.
    part: ContentPartOutputText
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class ResponseOutputItemDoneEvent(BaseStreamEvent):
    # Stream event tagged "response.output_item.done".
    type: Literal["response.output_item.done"]
    # Integer index of the item, followed by the item itself.
    output_index: int
    item: MessageItem
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class ResponseCompletedEvent(BaseStreamEvent):
    # Stream event tagged "response.completed"; carries a full response
    # object.
    type: Literal["response.completed"]
    response: OpenAIResponse
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
# Union of every streaming event model declared above.
StreamEvent = Union[
    ResponseCreatedEvent,
    ResponseInProgressEvent,
    ResponseOutputItemAddedEvent,
    ResponseContentPartAddedEvent,
    ResponseOutputTextDeltaEvent,
    ResponseOutputTextDoneEvent,
    ResponseContentPartDoneEvent,
    ResponseOutputItemDoneEvent,
    ResponseCompletedEvent,
]
|
|
385
|
+
|
|
386
|
+
# Models for /chat/completion endpoint
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
class VLMRequest(FlexibleBaseModel):
    # Shared request fields: which model/adapter to use and how to sample.
    # Defaults come from the generate module's constants.

    model: str = Field(
        default=DEFAULT_MODEL_PATH,
        description="The path to the local model directory or Hugging Face repo.",
    )
    adapter_path: Optional[str] = Field(
        default=None, description="The path to the adapter weights."
    )
    max_tokens: int = Field(
        default=DEFAULT_MAX_TOKENS,
        description="Maximum number of tokens to generate.",
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE, description="Temperature for sampling."
    )
    top_p: float = Field(default=DEFAULT_TOP_P, description="Top-p sampling.")
    seed: int = Field(
        default=DEFAULT_SEED, description="Seed for random generation."
    )
    # Optional pair of integers; None leaves the field unset.
    resize_shape: Optional[Tuple[int, int]] = Field(
        default=None,
        description="Resize shape for the image (height, width). Provide two integers.",
    )
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
class GenerationRequest(VLMRequest):
    """
    Inherits from VLMRequest and adds additional fields for the generation request.
    """

    # Streaming is off unless the request sets this flag.
    stream: bool = Field(
        default=False,
        description="Whether to stream the response chunk by chunk.",
    )
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
class UsageStats(OpenAIUsage):
    """
    Inherits from OpenAIUsage and adds additional fields for usage statistics.
    """

    # Three additional float metrics, none with a default.
    prompt_tps: float = Field(
        ..., description="Tokens per second for the prompt."
    )
    generation_tps: float = Field(
        ..., description="Tokens per second for the generation."
    )
    peak_memory: float = Field(
        ..., description="Peak memory usage during the generation."
    )
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
class ChatRequest(GenerationRequest):
    # Generation request extended with the list of chat messages.
    messages: List[ChatMessage]
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
class ChatChoice(BaseModel):
    # One choice of a non-streaming chat response: a finish reason string
    # and the message itself.
    finish_reason: str
    message: ChatMessage
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class ChatResponse(BaseModel):
    # Non-streaming chat response body.
    model: str
    choices: List[ChatChoice]
    # May be None, but is declared without a default value.
    usage: Optional[UsageStats]
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class ChatStreamChoice(BaseModel):
    # One choice of a streamed chat chunk.
    # Optional finish reason, None by default.
    finish_reason: Optional[str] = None
    # Message payload carried by this chunk.
    delta: ChatMessage
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class ChatStreamChunk(BaseModel):
    # Streamed chat chunk body.
    model: str
    choices: List[ChatStreamChoice]
    # May be None, but is declared without a default value.
    usage: Optional[UsageStats]
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
# Models for /models endpoint
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
class ModelInfo(BaseModel):
    # One entry of the models listing: string id, string object tag, and an
    # integer `created` value.
    id: str
    object: str
    created: int
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
class ModelsResponse(BaseModel):
    # Models listing body: the constant tag "list" plus the model entries.
    object: Literal["list"]
    data: List[ModelInfo]
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
# OpenAI compatible endpoints
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def _format_sse(event_name, payload):
    """Serialize a pydantic event model as one named Server-Sent Events frame."""
    return f"event: {event_name}\ndata: {payload.model_dump_json()}\n\n"


def _parse_responses_input(request_input):
    """Flatten a Responses-API ``input`` payload.

    Args:
        request_input: Either a plain string (treated as a single user
            message) or a list of ``ChatMessage`` objects whose content is a
            string or a list of ``input_text`` / ``input_image`` dicts.

    Returns:
        A ``(chat_messages, images, instructions)`` tuple: the list of
        ``{"role", "content"}`` dicts, the list of image URLs, and the text of
        the last system message seen (``None`` when there is none).

    Raises:
        HTTPException: 400 when the input is missing or malformed.
    """
    if not request_input:
        print("no input")
        raise HTTPException(status_code=400, detail="Missing input.")

    if isinstance(request_input, str):
        # A bare string is a single user turn.
        return [{"role": "user", "content": request_input}], [], None

    if not isinstance(request_input, list):
        print("neither string not list")
        raise HTTPException(status_code=400, detail="Invalid input format.")

    chat_messages = []
    images = []
    instructions = None

    for message in request_input:
        if not isinstance(message, ChatMessage):
            print("not a ChatMessage")
            raise HTTPException(status_code=400, detail="Invalid input format.")

        content = message.content
        if isinstance(content, str):
            chat_messages.append({"role": message.role, "content": content})
            if message.role == "system":
                instructions = content
            continue

        if not isinstance(content, list):
            print("Invalid message content format.")
            raise HTTPException(status_code=400, detail="Invalid input format.")

        for item in content:
            if not isinstance(item, dict):
                print(f"Invalid message content item format: {item}")
                raise HTTPException(
                    status_code=400,
                    detail="Missing type in input item.",
                )

            # .get() so a dict without "type" is rejected with a 400 below
            # instead of surfacing as a KeyError / 500.
            item_type = item.get("type")
            if item_type == "input_text":
                chat_messages.append({"role": message.role, "content": item["text"]})
                if message.role == "system":
                    instructions = item["text"]
            # examples for multiple images (https://platform.openai.com/docs/guides/images?api-mode=responses)
            elif item_type == "input_image":
                images.append(item["image_url"])
            else:
                print(f"invalid input item type: {item_type}")
                raise HTTPException(
                    status_code=400,
                    detail="Invalid input item type.",
                )

    return chat_messages, images, instructions


@app.post("/responses")
async def responses_endpoint(request: Request):
    """
    OpenAI-compatible endpoint for generating text from a prompt and optional
    images, usable through the ``client.responses.create`` method.

    The JSON body is validated as an ``OpenAIRequest``. When ``stream`` is
    set, the reply is a ``text/event-stream`` carrying the Responses-API event
    sequence (``response.created`` ... ``response.completed``); otherwise a
    single ``OpenAIResponse`` is returned.

    Raises:
        HTTPException: 400 for an invalid body or input, 500 when generation
            fails or an unexpected error occurs.

    Example::

        from openai import OpenAI

        API_URL = "http://localhost:8080"
        API_KEY = "any"

        def run_openai(prompt, img_url, system, stream=False,
                       max_output_tokens=512,
                       model="mlx-community/Qwen2.5-VL-3B-Instruct-8bit"):
            client = OpenAI(base_url=API_URL, api_key=API_KEY)
            response = client.responses.create(
                model=model,
                input=[
                    {"role": "system", "content": system},
                    {
                        "role": "user",
                        "content": [
                            {"type": "input_text", "text": prompt},
                            {"type": "input_image", "image_url": img_url},
                        ],
                    },
                ],
                max_output_tokens=max_output_tokens,
                stream=stream,
            )
            if not stream:
                print(response.output[0].content[0].text)
                print(response.usage)
            else:
                for event in response:
                    if hasattr(event, "delta") and event.delta:
                        print(event.delta, end="", flush=True)
                    elif event.type == "response.completed":
                        print()
                        print("--- Usage ---")
                        print(event.response.usage)
    """
    # Malformed JSON (ValueError), a non-object body (TypeError) and pydantic
    # validation failures (ValueError subclass in pydantic v2) are client
    # errors, not server errors.
    try:
        body = await request.json()
        openai_request = OpenAIRequest(**body)
    except (ValueError, TypeError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid request body: {e}"
        ) from e

    try:
        # Get model, processor, config - loading if necessary
        model, processor, config = get_cached_model(openai_request.model)

        chat_messages, images, instructions = _parse_responses_input(
            openai_request.input
        )

        formatted_prompt = apply_chat_template(
            processor, config, chat_messages, num_images=len(images)
        )

        generated_at = datetime.now().timestamp()
        response_id = f"resp_{uuid.uuid4().hex}"
        message_id = f"msg_{uuid.uuid4().hex}"

        if openai_request.stream:
            # Streaming response
            async def stream_generator():
                try:
                    # Create base response object (to match the openai pipeline)
                    base_response = OpenAIResponse(
                        id=response_id,
                        object="response",
                        created_at=int(generated_at),
                        status="in_progress",
                        instructions=instructions,
                        max_output_tokens=openai_request.max_output_tokens,
                        model=openai_request.model,
                        output=[],
                        output_text="",
                        temperature=openai_request.temperature,
                        top_p=openai_request.top_p,
                        usage={
                            "input_tokens": 0,  # get prompt tokens
                            "output_tokens": 0,
                            "total_tokens": 0,
                        },
                    )

                    # Opening events (to match the openai pipeline)
                    yield _format_sse(
                        "response.created",
                        ResponseCreatedEvent(
                            type="response.created", response=base_response
                        ),
                    )
                    yield _format_sse(
                        "response.in_progress",
                        ResponseInProgressEvent(
                            type="response.in_progress", response=base_response
                        ),
                    )

                    message_item = MessageItem(
                        id=message_id,
                        type="message",
                        status="in_progress",
                        role="assistant",
                        content=[],
                    )
                    yield _format_sse(
                        "response.output_item.added",
                        ResponseOutputItemAddedEvent(
                            type="response.output_item.added",
                            output_index=0,
                            item=message_item,
                        ),
                    )

                    content_part = ContentPartOutputText(
                        type="output_text", text="", annotations=[]
                    )
                    yield _format_sse(
                        "response.content_part.added",
                        ResponseContentPartAddedEvent(
                            type="response.content_part.added",
                            item_id=message_id,
                            output_index=0,
                            content_index=0,
                            part=content_part,
                        ),
                    )

                    # Stream text deltas
                    token_iterator = stream_generate(
                        model=model,
                        processor=processor,
                        prompt=formatted_prompt,
                        image=images,
                        temperature=openai_request.temperature,
                        max_tokens=openai_request.max_output_tokens,
                        top_p=openai_request.top_p,
                    )

                    full_text = ""
                    # Pre-seeded so the completed event can still be built when
                    # the iterator yields no usable chunk.
                    usage_stats = {"input_tokens": 0, "output_tokens": 0}
                    for chunk in token_iterator:
                        if chunk is None or not hasattr(chunk, "text"):
                            continue

                        delta = chunk.text
                        full_text += delta

                        usage_stats = {
                            "input_tokens": chunk.prompt_tokens,
                            "output_tokens": chunk.generation_tokens,
                        }

                        yield _format_sse(
                            "response.output_text.delta",
                            ResponseOutputTextDeltaEvent(
                                type="response.output_text.delta",
                                item_id=message_id,
                                output_index=0,
                                content_index=0,
                                delta=delta,
                            ),
                        )
                        # Give the event loop a chance to run between tokens.
                        await asyncio.sleep(0.01)

                    # Closing events (to match the openai pipeline)
                    yield _format_sse(
                        "response.output_text.done",
                        ResponseOutputTextDoneEvent(
                            type="response.output_text.done",
                            item_id=message_id,
                            output_index=0,
                            content_index=0,
                            text=full_text,
                        ),
                    )

                    final_content_part = ContentPartOutputText(
                        type="output_text", text=full_text, annotations=[]
                    )
                    yield _format_sse(
                        "response.content_part.done",
                        ResponseContentPartDoneEvent(
                            type="response.content_part.done",
                            item_id=message_id,
                            output_index=0,
                            content_index=0,
                            part=final_content_part,
                        ),
                    )

                    final_message_item = MessageItem(
                        id=message_id,
                        type="message",
                        status="completed",
                        role="assistant",
                        content=[final_content_part],
                    )
                    yield _format_sse(
                        "response.output_item.done",
                        ResponseOutputItemDoneEvent(
                            type="response.output_item.done",
                            output_index=0,
                            item=final_message_item,
                        ),
                    )

                    completed_response = base_response.model_copy(
                        update={
                            "status": "completed",
                            "output": [final_message_item],
                            "usage": {
                                "input_tokens": usage_stats["input_tokens"],
                                "output_tokens": usage_stats["output_tokens"],
                                "total_tokens": usage_stats["input_tokens"]
                                + usage_stats["output_tokens"],
                            },
                        }
                    )
                    yield _format_sse(
                        "response.completed",
                        ResponseCompletedEvent(
                            type="response.completed", response=completed_response
                        ),
                    )

                except Exception as e:
                    # The HTTP status is already sent once streaming starts, so
                    # the failure is reported in-band as a data frame.
                    print(f"Error during stream generation: {e}")
                    traceback.print_exc()
                    error_data = json.dumps({"error": str(e)})
                    yield f"data: {error_data}\n\n"

                finally:
                    mx.clear_cache()
                    gc.collect()
                    print("Stream finished, cleared cache.")

            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        else:
            # Non-streaming response
            try:
                # Use generate from generate.py
                result = generate(
                    model=model,
                    processor=processor,
                    prompt=formatted_prompt,
                    image=images,
                    temperature=openai_request.temperature,
                    max_tokens=openai_request.max_output_tokens,
                    top_p=openai_request.top_p,
                    verbose=False,  # stats are passed in the response
                )
                # Clean up resources
                mx.clear_cache()
                gc.collect()
                print("Generation finished, cleared cache.")

                response = OpenAIResponse(
                    id=response_id,
                    object="response",
                    created_at=int(generated_at),
                    status="completed",
                    instructions=instructions,
                    max_output_tokens=openai_request.max_output_tokens,
                    model=openai_request.model,
                    output=[
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "output_text",
                                    "text": result.text,
                                }
                            ],
                        }
                    ],
                    output_text=result.text,
                    temperature=openai_request.temperature,
                    top_p=openai_request.top_p,
                    usage={
                        "input_tokens": result.prompt_tokens,
                        "output_tokens": result.generation_tokens,
                        "total_tokens": result.total_tokens,
                    },
                )
                return response

            except Exception as e:
                print(f"Error during generation: {e}")
                traceback.print_exc()
                mx.clear_cache()
                gc.collect()
                raise HTTPException(
                    status_code=500, detail=f"Generation failed: {e}"
                ) from e

    except HTTPException:
        # Re-raise HTTP exceptions (like model loading failure)
        raise
    except Exception as e:
        # Catch unexpected errors
        print(f"Unexpected error in /responses endpoint: {e}")
        traceback.print_exc()
        mx.clear_cache()
        gc.collect()
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {e}"
        ) from e
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
@app.post(
    "/chat/completions", response_model=None
)  # Response model handled dynamically based on stream flag
async def chat_completions_endpoint(request: ChatRequest):
    """
    Generate text based on a prompt and optional images.
    Prompt must be a list of chat messages, including system, user, and assistant messages.
    System message will be ignored if not already in the prompt.
    Can operate in streaming or non-streaming mode.

    Images (``input_image`` / ``image_url`` items) and audio (``input_audio``
    items) are collected from user messages only. For a message whose content
    is a list, the text of the last ``text`` / ``input_text`` item is used as
    that message's content.

    Returns:
        A ``StreamingResponse`` of ``ChatStreamChunk`` SSE frames when
        ``request.stream`` is set, otherwise a ``ChatResponse``.

    Raises:
        HTTPException: 400 for an invalid ``resize_shape``; 500 when
            generation fails or an unexpected error occurs.
    """

    try:
        # Get model, processor, config - loading if necessary
        model, processor, config = get_cached_model(request.model, request.adapter_path)

        kwargs = {}

        if request.resize_shape is not None:
            if len(request.resize_shape) not in [1, 2]:
                raise HTTPException(
                    status_code=400,
                    detail="resize_shape must contain one or two integers "
                    "(size, or height and width)",
                )
            # A single value is used for both dimensions.
            kwargs["resize_shape"] = (
                (request.resize_shape[0],) * 2
                if len(request.resize_shape) == 1
                else tuple(request.resize_shape)
            )

        images = []
        audio = []
        processed_messages = []
        for message in request.messages:
            if isinstance(message.content, str):
                processed_messages.append(
                    {"role": message.role, "content": message.content}
                )
            elif isinstance(message.content, list):
                text_content = ""
                for item in message.content:
                    if isinstance(item, dict):
                        # Only extract images/audio from user messages
                        if message.role == "user":
                            if item["type"] == "input_image":
                                images.append(item["image_url"])
                            elif item["type"] == "image_url":
                                images.append(item["image_url"]["url"])
                            elif item["type"] == "input_audio":
                                audio.append(item["input_audio"]["data"])
                        # NOTE(review): a later text item replaces an earlier
                        # one rather than being appended - confirm intended.
                        if item["type"] in ("text", "input_text"):
                            text_content = item.get("text", "")
                processed_messages.append(
                    {"role": message.role, "content": text_content}
                )

        formatted_prompt = apply_chat_template(
            processor,
            config,
            processed_messages,
            num_images=len(images),
            num_audios=len(audio),
        )

        if request.stream:
            # Streaming response
            async def stream_generator():
                # Stays None when no usable chunk arrives, so the terminating
                # chunk can still be built.
                usage_stats = None
                try:
                    # Use stream_generate from utils
                    token_iterator = stream_generate(
                        model=model,
                        processor=processor,
                        prompt=formatted_prompt,
                        image=images,
                        audio=audio,
                        temperature=request.temperature,
                        max_tokens=request.max_tokens,
                        top_p=request.top_p,
                        **kwargs,
                    )

                    for chunk in token_iterator:
                        if chunk is None or not hasattr(chunk, "text"):
                            print("Warning: Received unexpected chunk format:", chunk)
                            continue

                        # Yield chunks in Server-Sent Events (SSE) format
                        usage_stats = {
                            "input_tokens": chunk.prompt_tokens,
                            "output_tokens": chunk.generation_tokens,
                            "total_tokens": chunk.prompt_tokens
                            + chunk.generation_tokens,
                            "prompt_tps": chunk.prompt_tps,
                            "generation_tps": chunk.generation_tps,
                            "peak_memory": chunk.peak_memory,
                        }

                        choices = [
                            ChatStreamChoice(
                                delta=ChatMessage(role="assistant", content=chunk.text)
                            )
                        ]
                        chunk_data = ChatStreamChunk(
                            model=request.model, usage=usage_stats, choices=choices
                        )

                        yield f"data: {chunk_data.model_dump_json()}\n\n"
                        await asyncio.sleep(
                            0.01
                        )  # Small sleep to prevent blocking event loop entirely

                    # Signal stream end
                    choices = [
                        ChatStreamChoice(
                            finish_reason="stop",
                            delta=ChatMessage(role="assistant", content=""),
                        )
                    ]
                    chunk_data = ChatStreamChunk(
                        model=request.model, usage=usage_stats, choices=choices
                    )
                    yield f"data: {chunk_data.model_dump_json()}\n\n"

                except Exception as e:
                    # The HTTP status is already sent once streaming starts, so
                    # the failure is reported in-band as a data frame.
                    print(f"Error during stream generation: {e}")
                    traceback.print_exc()
                    error_data = json.dumps({"error": str(e)})
                    yield f"data: {error_data}\n\n"

                finally:
                    mx.clear_cache()
                    gc.collect()
                    print("Stream finished, cleared cache.")

            return StreamingResponse(stream_generator(), media_type="text/event-stream")

        else:
            # Non-streaming response
            try:
                # Use generate from generate.py
                gen_result = generate(
                    model=model,
                    processor=processor,
                    prompt=formatted_prompt,
                    image=images,
                    audio=audio,
                    temperature=request.temperature,
                    max_tokens=request.max_tokens,
                    top_p=request.top_p,
                    verbose=False,  # Keep API output clean
                    **kwargs,
                )
                # Clean up resources
                mx.clear_cache()
                gc.collect()
                print("Generation finished, cleared cache.")

                usage_stats = UsageStats(
                    input_tokens=gen_result.prompt_tokens,
                    output_tokens=gen_result.generation_tokens,
                    total_tokens=gen_result.total_tokens,
                    prompt_tps=gen_result.prompt_tps,
                    generation_tps=gen_result.generation_tps,
                    peak_memory=gen_result.peak_memory,
                )

                choices = [
                    ChatChoice(
                        finish_reason="stop",
                        message=ChatMessage(role="assistant", content=gen_result.text),
                    )
                ]
                result = ChatResponse(
                    model=request.model, usage=usage_stats, choices=choices
                )

                return result

            except Exception as e:
                print(f"Error during generation: {e}")
                traceback.print_exc()
                mx.clear_cache()
                gc.collect()
                raise HTTPException(
                    status_code=500, detail=f"Generation failed: {e}"
                ) from e

    except HTTPException:
        # Re-raise HTTP exceptions (like model loading failure)
        raise
    except Exception as e:
        # Catch unexpected errors
        print(f"Unexpected error in /chat/completions endpoint: {e}")
        traceback.print_exc()
        mx.clear_cache()
        gc.collect()
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {e}"
        ) from e
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
@app.get("/models", response_model=ModelsResponse)
def models_endpoint():
    """
    List the models found in the local Hugging Face cache.

    A cached repository is reported when it is of type "model", has a "main"
    ref, and that ref contains config.json, model.safetensors.index.json and
    tokenizer_config.json.
    """
    required_files = {
        "config.json",
        "model.safetensors.index.json",
        "tokenizer_config.json",
    }

    def probably_mlx_lm(repo):
        # Heuristic: the repo qualifies only if every required file is present
        # on its "main" revision.
        if repo.repo_type == "model" and "main" in repo.refs:
            present = {entry.file_path.name for entry in repo.refs["main"].files}
            return required_files <= present
        return False

    # Walk the cache and describe every repository that passes the check.
    entries = []
    for repo in scan_cache_dir().repos:
        if probably_mlx_lm(repo):
            entries.append(
                {
                    "id": repo.repo_id,
                    "object": "model",
                    "created": int(repo.last_modified),
                }
            )

    return {"object": "list", "data": entries}
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
# MLX_VLM API endpoints
|
|
1045
|
+
|
|
1046
|
+
|
|
1047
|
+
@app.get("/health")
async def health_check():
    """
    Report that the server is up, together with the model and adapter paths
    currently recorded in the model cache (``None`` when absent).
    """
    loaded_model = model_cache.get("model_path", None)
    loaded_adapter = model_cache.get("adapter_path", None)
    return {
        "status": "healthy",
        "loaded_model": loaded_model,
        "loaded_adapter": loaded_adapter,
    }
|
|
1057
|
+
|
|
1058
|
+
|
|
1059
|
+
@app.post("/unload")
async def unload_model_endpoint():
    """
    Unload the currently loaded model from memory.

    Returns:
        A dict with ``status`` "no_model_loaded" when ``unload_model_sync()``
        reports nothing was unloaded; otherwise ``status`` "success" plus the
        model and adapter paths that were recorded before unloading.
    """
    # Capture the names first: they are read from the cache before the unload
    # call runs.
    unloaded_info = {
        "model_name": model_cache.get("model_path", None),
        "adapter_name": model_cache.get("adapter_path", None),
    }

    if not unload_model_sync():  # Use the synchronous unload function
        return {"status": "no_model_loaded", "message": "No model is currently loaded"}

    return {
        "status": "success",
        "message": "Model unloaded successfully",
        "unloaded": unloaded_info,
    }
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
def main():
    """Parse the command-line options and start the uvicorn HTTP server.

    ``--trust-remote-code`` is forwarded through the ``MLX_TRUST_REMOTE_CODE``
    environment variable.
    """
    cli = argparse.ArgumentParser(description="MLX VLM Http Server.")
    cli.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host for the HTTP server (default:0.0.0.0)",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port for the HTTP server (default: 8080)",
    )
    cli.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code when loading models from Hugging Face Hub.",
    )
    options = cli.parse_args()

    if options.trust_remote_code:
        os.environ["MLX_TRUST_REMOTE_CODE"] = "true"

    # reload=True for development to automatically restart on code changes.
    uvicorn.run(
        "mlx_vlm.server:app",
        host=options.host,
        port=options.port,
        workers=1,
        reload=True,
    )


if __name__ == "__main__":
    main()
|