@modular-prompt/driver 0.12.0 → 0.13.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/anthropic/anthropic-driver.d.ts +38 -8
- package/dist/anthropic/anthropic-driver.d.ts.map +1 -1
- package/dist/anthropic/anthropic-driver.js +180 -164
- package/dist/anthropic/anthropic-driver.js.map +1 -1
- package/dist/cache-controller.d.ts +28 -0
- package/dist/cache-controller.d.ts.map +1 -0
- package/dist/cache-controller.js +2 -0
- package/dist/cache-controller.js.map +1 -0
- package/dist/cache-utils.d.ts +20 -0
- package/dist/cache-utils.d.ts.map +1 -0
- package/dist/cache-utils.js +71 -0
- package/dist/cache-utils.js.map +1 -0
- package/dist/content-utils.d.ts.map +1 -1
- package/dist/content-utils.js +20 -0
- package/dist/content-utils.js.map +1 -1
- package/dist/driver-registry/config-based-factory.d.ts.map +1 -1
- package/dist/driver-registry/config-based-factory.js +7 -0
- package/dist/driver-registry/config-based-factory.js.map +1 -1
- package/dist/driver-registry/factory-helper.d.ts.map +1 -1
- package/dist/driver-registry/factory-helper.js +7 -4
- package/dist/driver-registry/factory-helper.js.map +1 -1
- package/dist/driver-registry/types.d.ts +6 -0
- package/dist/driver-registry/types.d.ts.map +1 -1
- package/dist/formatter/converter.js +1 -1
- package/dist/formatter/converter.js.map +1 -1
- package/dist/google-genai/element-converter.d.ts +11 -0
- package/dist/google-genai/element-converter.d.ts.map +1 -0
- package/dist/google-genai/element-converter.js +126 -0
- package/dist/google-genai/element-converter.js.map +1 -0
- package/dist/google-genai/google-genai-cache-controller.d.ts +24 -0
- package/dist/google-genai/google-genai-cache-controller.d.ts.map +1 -0
- package/dist/google-genai/google-genai-cache-controller.js +127 -0
- package/dist/google-genai/google-genai-cache-controller.js.map +1 -0
- package/dist/google-genai/google-genai-driver.d.ts +5 -29
- package/dist/google-genai/google-genai-driver.d.ts.map +1 -1
- package/dist/google-genai/google-genai-driver.js +92 -255
- package/dist/google-genai/google-genai-driver.js.map +1 -1
- package/dist/index.d.ts +4 -0
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +3 -0
- package/dist/index.js.map +1 -1
- package/dist/mlx-ml/mlx-cache-controller.d.ts +66 -0
- package/dist/mlx-ml/mlx-cache-controller.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-cache-controller.js +600 -0
- package/dist/mlx-ml/mlx-cache-controller.js.map +1 -0
- package/dist/mlx-ml/mlx-driver.d.ts +12 -7
- package/dist/mlx-ml/mlx-driver.d.ts.map +1 -1
- package/dist/mlx-ml/mlx-driver.js +192 -124
- package/dist/mlx-ml/mlx-driver.js.map +1 -1
- package/dist/mlx-ml/mlx-message-utils.d.ts +9 -0
- package/dist/mlx-ml/mlx-message-utils.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-message-utils.js +71 -0
- package/dist/mlx-ml/mlx-message-utils.js.map +1 -0
- package/dist/mlx-ml/process/index.d.ts +7 -3
- package/dist/mlx-ml/process/index.d.ts.map +1 -1
- package/dist/mlx-ml/process/index.js +22 -7
- package/dist/mlx-ml/process/index.js.map +1 -1
- package/dist/mlx-ml/process/model-handlers.d.ts +4 -59
- package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -1
- package/dist/mlx-ml/process/model-handlers.js +15 -14
- package/dist/mlx-ml/process/model-handlers.js.map +1 -1
- package/dist/mlx-ml/process/model-specific.d.ts +7 -0
- package/dist/mlx-ml/process/model-specific.d.ts.map +1 -1
- package/dist/mlx-ml/process/model-specific.js +3 -0
- package/dist/mlx-ml/process/model-specific.js.map +1 -1
- package/dist/mlx-ml/process/process-communication.d.ts +3 -0
- package/dist/mlx-ml/process/process-communication.d.ts.map +1 -1
- package/dist/mlx-ml/process/process-communication.js +13 -0
- package/dist/mlx-ml/process/process-communication.js.map +1 -1
- package/dist/mlx-ml/process/queue.d.ts +5 -2
- package/dist/mlx-ml/process/queue.d.ts.map +1 -1
- package/dist/mlx-ml/process/queue.js +101 -14
- package/dist/mlx-ml/process/queue.js.map +1 -1
- package/dist/mlx-ml/process/response-processor.d.ts +10 -0
- package/dist/mlx-ml/process/response-processor.d.ts.map +1 -1
- package/dist/mlx-ml/process/response-processor.js +23 -1
- package/dist/mlx-ml/process/response-processor.js.map +1 -1
- package/dist/mlx-ml/process/types.d.ts +50 -4
- package/dist/mlx-ml/process/types.d.ts.map +1 -1
- package/dist/mlx-ml/tool-call-parser.d.ts.map +1 -1
- package/dist/mlx-ml/tool-call-parser.js +44 -25
- package/dist/mlx-ml/tool-call-parser.js.map +1 -1
- package/dist/types.d.ts +2 -0
- package/dist/types.d.ts.map +1 -1
- package/package.json +7 -4
- package/src/mlx-ml/python/__main__.py +41 -449
- package/src/mlx-ml/python/backends/__init__.py +3 -0
- package/src/mlx-ml/python/backends/base.py +84 -0
- package/src/mlx-ml/python/backends/mlx_lm.py +202 -0
- package/src/mlx-ml/python/backends/mlx_vlm.py +99 -0
- package/src/mlx-ml/python/handlers/__init__.py +6 -0
- package/src/mlx-ml/python/handlers/cache.py +81 -0
- package/src/mlx-ml/python/handlers/capabilities.py +6 -0
- package/src/mlx-ml/python/handlers/chat.py +221 -0
- package/src/mlx-ml/python/handlers/completion.py +36 -0
- package/src/mlx-ml/python/handlers/format_test.py +70 -0
- package/src/mlx-ml/python/handlers/tokenize.py +63 -0
- package/src/mlx-ml/python/pyproject.toml +13 -3
- package/src/mlx-ml/python/server.py +126 -0
- package/src/mlx-ml/python/tests/__init__.py +0 -0
- package/src/mlx-ml/python/utils/__init__.py +0 -0
- package/src/mlx-ml/python/utils/prompt_builder.py +54 -0
- package/src/mlx-ml/python/{token_utils.py → utils/token_utils.py} +1 -2
- package/src/mlx-ml/python/uv.lock +266 -41
- /package/src/mlx-ml/python/{example_basic.py → examples/example_basic.py} +0 -0
- /package/src/mlx-ml/python/{example_tool_call.py → examples/example_tool_call.py} +0 -0
- /package/src/mlx-ml/python/{chat_template_constraints.py → utils/chat_template_constraints.py} +0 -0
- /package/src/mlx-ml/python/{vlm_utils.py → utils/vlm_utils.py} +0 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Any, Iterator
|
|
7
|
+
|
|
8
|
+
from mlx_lm import load as mlx_lm_load
|
|
9
|
+
from mlx_lm import stream_generate as mlx_lm_stream_generate
|
|
10
|
+
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache, trim_prompt_cache
|
|
11
|
+
from mlx_lm.sample_utils import make_sampler
|
|
12
|
+
|
|
13
|
+
from backends.base import ModelBackend
|
|
14
|
+
from utils.token_utils import is_eod_token
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MlxLmBackend(ModelBackend):
    """`mlx_lm` backend for text-only models.

    Wraps model loading, streaming generation and on-disk prompt-cache
    prefill. Every cache written by :meth:`cache_prefill` gets a
    ``<cache_path>.meta.json`` sidecar recording the prompt's token count.
    """

    def __init__(self) -> None:
        # Both attributes stay None until load() has been called.
        self.model: Any | None = None
        self.tokenizer: Any | None = None

    def load(self, model_name: str) -> None:
        """Load the model and tokenizer for *model_name* via ``mlx_lm.load``."""
        self.model, self.tokenizer = mlx_lm_load(model_name)

    def get_tokenizer(self) -> Any:
        """Return the tokenizer set by :meth:`load` (``None`` before loading)."""
        return self.tokenizer

    def stream_generate(
        self,
        prompt: str | list[int],
        options: dict,
        images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Yield responses from ``mlx_lm.stream_generate`` for *prompt*.

        ``temperature``, ``top_p`` and ``top_k`` are removed from *options*
        and turned into a sampler; all remaining options (with a default
        ``max_tokens`` of 1000) are forwarded as keyword arguments. Iteration
        stops as soon as ``is_eod_token`` reports true for a response; that
        response is not yielded. *images* is accepted for interface
        compatibility and is not used by this backend.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        final_options = {"max_tokens": 1000, **options}
        temperature = final_options.pop("temperature", 1.0)
        top_p = final_options.pop("top_p", 0.0)
        top_k = final_options.pop("top_k", 0)
        final_options["sampler"] = make_sampler(
            temp=temperature,
            top_p=top_p,
            top_k=top_k,
        )

        if prompt_cache is not None:
            final_options["prompt_cache"] = prompt_cache

        for response in mlx_lm_stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            **final_options,
        ):
            if is_eod_token(response, self.tokenizer):
                break
            yield response

    # get_cache_offset is inherited from ModelBackend base class

    def _tokenize_prompt(self, prompt: str) -> list[int]:
        """Tokenize a prompt string using the same logic as stream_generate."""
        # Only let the tokenizer add special tokens when the prompt does not
        # already start with the BOS token (avoids a duplicated BOS).
        add_special = self.tokenizer.bos_token is None or not prompt.startswith(
            self.tokenizer.bos_token
        )
        return self.tokenizer.encode(prompt, add_special_tokens=add_special)

    @staticmethod
    def _write_cache_meta(
        cache_path: str,
        token_count: int,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> None:
        """Write the ``<cache_path>.meta.json`` sidecar (best effort).

        The prefix fields are stored only when both are non-empty. Any
        failure is reported on stderr and otherwise ignored, so a failed
        sidecar write never aborts the prefill itself.
        """
        meta_path = cache_path + '.meta.json'
        try:
            meta: dict[str, Any] = {"token_count": token_count}
            if prefix_offsets and prefix_hashes:
                meta["prefix_offsets"] = prefix_offsets
                meta["prefix_hashes"] = prefix_hashes
            with open(meta_path, 'w', encoding='utf-8') as f:
                json.dump(meta, f)
        except Exception as e:
            sys.stderr.write(f"Failed to write cache meta: {e}\n")

    @staticmethod
    def _read_cache_meta(cache_path: str) -> int | None:
        """Return ``token_count`` from the sidecar, or ``None`` if unusable.

        ``None`` covers a missing or unreadable file, invalid JSON, JSON that
        is not an object, and a missing or non-integer ``token_count``.
        """
        meta_path = cache_path + '.meta.json'
        try:
            with open(meta_path, encoding='utf-8') as f:
                meta = json.load(f)
            count = meta.get('token_count')
            return int(count) if count is not None else None
        except (OSError, json.JSONDecodeError, ValueError, TypeError, AttributeError):
            # AttributeError: valid JSON whose top level is not an object.
            return None

    def cache_prefill(
        self,
        cache_path: str,
        prompt: str,
        base_cache_path: str | None = None,
        trim_to_tokens: int | None = None,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> dict:
        """Prefill a prompt cache for *prompt* and save it to *cache_path*.

        When *base_cache_path* is given, that cache is loaded and only the
        prompt tokens beyond its offset are processed. The offset comes from
        the cache itself (after optionally trimming it to *trim_to_tokens*)
        or, when no trim is requested, from the base cache's sidecar file.
        Any problem with the base cache falls back to a fresh cache and a
        full prefill of the prompt.

        Returns:
            ``{"cache_path": cache_path, "token_count": <prompt tokens>}``.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        full_tokens = self._tokenize_prompt(prompt)
        token_count = len(full_tokens)
        effective_prompt: str | list[int] = prompt

        if base_cache_path is not None:
            try:
                prompt_cache = load_prompt_cache(base_cache_path)

                if trim_to_tokens is not None:
                    current_offset = self.get_cache_offset(prompt_cache)
                    if current_offset > trim_to_tokens:
                        trim_count = current_offset - trim_to_tokens
                        # trim_prompt_cache reports how many tokens it really
                        # removed (0 for untrimmable caches). A short trim
                        # would leave offset and contents out of sync, so it
                        # is treated as a base-cache failure.
                        trimmed = trim_prompt_cache(prompt_cache, trim_count)
                        if trimmed != trim_count:
                            raise RuntimeError(
                                f"could not trim base cache to {trim_to_tokens} tokens "
                                f"(trimmed {trimmed} of {trim_count})"
                            )
                        if os.getenv('MLX_DEBUG'):
                            sys.stderr.write(
                                f"Trimmed base cache: {current_offset} → {trim_to_tokens} tokens\n"
                            )
                        cache_offset = trim_to_tokens
                    else:
                        cache_offset = current_offset
                else:
                    meta_offset = self._read_cache_meta(base_cache_path)
                    if meta_offset is None:
                        # Legacy cache without meta file - create fresh cache
                        sys.stderr.write(
                            f"WARNING: Cache file exists but no .meta.json found at {base_cache_path}. "
                            "Creating fresh cache (may be from old implementation).\n"
                        )
                        cache_offset = 0
                        prompt_cache = make_prompt_cache(self.model)
                    else:
                        cache_offset = meta_offset

                if cache_offset > 0 and prompt_cache is not None:
                    if cache_offset < token_count:
                        # Only the tokens not yet covered by the base cache
                        # need to be processed.
                        effective_prompt = full_tokens[cache_offset:]
                        if os.getenv('MLX_DEBUG'):
                            sys.stderr.write(
                                f"Incremental prefill from: {base_cache_path} "
                                f"(skip {cache_offset}/{token_count} tokens)\n"
                            )
                    else:
                        if os.getenv('MLX_DEBUG'):
                            sys.stderr.write(
                                f"Base cache covers entire prompt "
                                f"({cache_offset} >= {token_count}), saving as-is\n"
                            )
                        # NOTE(review): when cache_offset > token_count the
                        # cache is saved untrimmed while the sidecar records
                        # token_count — confirm callers never rely on this.
                        save_prompt_cache(cache_path, prompt_cache)
                        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
                        return {"cache_path": cache_path, "token_count": token_count}
                else:
                    if os.getenv('MLX_DEBUG'):
                        sys.stderr.write(f"Incremental prefill from: {base_cache_path}\n")
            except Exception as e:
                sys.stderr.write(f"Base cache load failed, creating fresh: {e}\n")
                prompt_cache = make_prompt_cache(self.model)
                # A fresh cache holds nothing, so the whole prompt must be
                # processed even if a slice had already been selected.
                effective_prompt = prompt
        else:
            prompt_cache = make_prompt_cache(self.model)

        if os.getenv('MLX_DEBUG'):
            sys.stderr.write(f"Prefill prompt: {token_count} tokens\n")

        # Run generation just far enough to process the prompt into the
        # cache; the generated output itself is discarded.
        for _ in mlx_lm_stream_generate(
            self.model, self.tokenizer, effective_prompt,
            prompt_cache=prompt_cache, max_tokens=1,
        ):
            break

        save_prompt_cache(cache_path, prompt_cache)
        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
        if os.getenv('MLX_DEBUG'):
            sys.stderr.write(f"Cache created: {cache_path} ({token_count} tokens)\n")
        return {"cache_path": cache_path, "token_count": token_count}

    def load_cache_from_file(self, cache_path: str) -> list | None:
        """Load a saved prompt cache; return ``None`` (and log) on any failure."""
        try:
            return load_prompt_cache(cache_path)
        except FileNotFoundError:
            sys.stderr.write(f"Cache file not found: {cache_path}\n")
            return None
        except Exception as e:
            sys.stderr.write(f"Failed to load cache: {e}\n")
            return None

    def supports_vision(self) -> bool:
        """Return ``False``: this backend handles text-only models."""
        return False

    @property
    def model_kind(self) -> str:
        """Backend kind identifier (``"lm"``)."""
        return "lm"
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Any, Iterator
|
|
5
|
+
|
|
6
|
+
from mlx_vlm import load as mlx_vlm_load
|
|
7
|
+
from mlx_vlm import stream_generate as mlx_vlm_stream_generate
|
|
8
|
+
|
|
9
|
+
from backends.base import ModelBackend
|
|
10
|
+
from utils.vlm_utils import load_and_resize_images
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MlxVlmBackend(ModelBackend):
    """`mlx_vlm` backend for vision-language models."""

    def __init__(self) -> None:
        # Model/processor are set by load(); drafter fields by load_drafter().
        self.model: Any | None = None
        self.processor: Any | None = None
        self.drafter: Any | None = None
        self.drafter_kind: str | None = None
        self.draft_block_size: int | None = None

    def load(self, model_name: str) -> None:
        """Load the model and its processor via ``mlx_vlm.load``."""
        loaded = mlx_vlm_load(model_name)
        self.model, self.processor = loaded

    def load_drafter(self, drafter_model: str) -> None:
        """Load a speculative-decoding drafter and log its kind to stderr."""
        from mlx_vlm.speculative.drafters import load_drafter

        self.drafter, self.drafter_kind = load_drafter(drafter_model)
        sys.stderr.write(f"Drafter loaded: {drafter_model} (kind={self.drafter_kind})\n")

    def has_drafter(self) -> bool:
        """Return whether a drafter has been loaded."""
        return self.drafter is not None

    def get_tokenizer(self) -> Any:
        """Return the processor, which serves as this backend's tokenizer."""
        return self.processor

    def stream_generate(
        self,
        prompt: str | list[int],
        options: dict,
        images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Yield results from ``mlx_vlm.stream_generate`` for *prompt*.

        Sampling settings are read from *options* (defaults: temperature 1.0,
        max_tokens 1000, top_p 0.0, top_k 0). Images, when given, are loaded
        and resized to ``max_image_size`` (default 768) first. *prompt_cache*
        is accepted but not used here. After the stream is exhausted,
        speculative-decoding statistics are written to stderr unless
        ``MLX_NO_STATS`` is set.

        Raises:
            RuntimeError: If the model or processor has not been loaded.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded")

        remaining = dict(options)
        sampling = {
            "max_tokens": remaining.pop("max_tokens", 1000),
            "temperature": remaining.pop("temperature", 1.0),
            "top_p": remaining.pop("top_p", 0.0),
            "top_k": remaining.pop("top_k", 0),
        }

        resized = None
        if images:
            size_limit = remaining.pop("max_image_size", 768)
            resized = load_and_resize_images(images, size_limit)

        # Drafter arguments are only forwarded when a drafter is loaded.
        speculative = {}
        if self.drafter:
            speculative["draft_model"] = self.drafter
            speculative["draft_kind"] = self.drafter_kind
            if self.draft_block_size is not None:
                speculative["draft_block_size"] = self.draft_block_size

        # Generate and count the yielded results for the stats report.
        emitted = 0
        stream = mlx_vlm_stream_generate(
            self.model,
            self.processor,
            prompt,
            image=resized,
            **sampling,
            **speculative,
        )
        for chunk in stream:
            emitted += 1
            yield chunk

        # Output speculative decoding stats if drafter was used.
        if not (self.drafter and hasattr(self.drafter, 'accept_lens')):
            return
        rounds = self.drafter.accept_lens
        if not rounds:
            return

        mean_accepted = sum(rounds) / len(rounds)
        # Show stats unless MLX_NO_STATS environment variable is set.
        import os
        if not os.getenv('MLX_NO_STATS'):
            report = (
                "\n[Speculative Decoding Stats]\n",
                f"  Rounds: {len(rounds)}\n",
                f"  Average accepted tokens/round: {mean_accepted:.2f}\n",
                f"  Total tokens generated: {emitted}\n",
                f"  Speedup factor: {mean_accepted:.2f}x (theoretical)\n",
            )
            for line in report:
                sys.stderr.write(line)
        # Clear for next generation.
        self.drafter.accept_lens = []

    def supports_vision(self) -> bool:
        """Return ``True``: this backend accepts image inputs."""
        return True

    @property
    def model_kind(self) -> str:
        """Backend kind identifier (``"vlm"``)."""
        return "vlm"
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from handlers.cache import handle_cache_prefill
|
|
2
|
+
from handlers.capabilities import handle_capabilities
|
|
3
|
+
from handlers.chat import handle_chat
|
|
4
|
+
from handlers.completion import handle_completion
|
|
5
|
+
from handlers.format_test import handle_format_test
|
|
6
|
+
from handlers.tokenize import handle_tokenize
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from backends.base import ModelBackend
|
|
7
|
+
from utils.prompt_builder import generate_merged_prompt, supports_chat_template
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def handle_cache_prefill(
    backend: ModelBackend,
    capabilities: dict,
    cache_path: str,
    messages: list,
    base_cache_path: str | None = None,
    trim_to_tokens: int | None = None,
    prefix_offsets: list[int] | None = None,
    prefix_hashes: list[str] | None = None,
    tools: list | None = None,
    reasoning_effort: str | None = None,
) -> None:
    """Render *messages* to a prompt, prefill a cache for it, and report the result.

    The prompt is rendered without a generation prompt, passed to
    ``backend.cache_prefill`` and the returned dict is printed to stdout as
    JSON terminated by a NUL byte. When both *prefix_offsets* and
    *prefix_hashes* are non-empty they are echoed back in that JSON.
    """
    tokenizer = backend.get_tokenizer()
    prompt = _render_prefill_prompt(
        tokenizer, messages, capabilities, tools, reasoning_effort
    )

    # Only show debug output if MLX_DEBUG environment variable is set
    import os
    if os.getenv('MLX_DEBUG'):
        sys.stderr.write(f"--- cache_prefill {cache_path}\n")

    result = backend.cache_prefill(
        cache_path, prompt, base_cache_path,
        trim_to_tokens=trim_to_tokens,
        prefix_offsets=prefix_offsets,
        prefix_hashes=prefix_hashes,
    )
    if prefix_offsets and prefix_hashes:
        result["prefix_offsets"] = prefix_offsets
        result["prefix_hashes"] = prefix_hashes
    print(json.dumps(result), end="\0", flush=True)


def _render_prefill_prompt(tokenizer, messages, capabilities, tools, reasoning_effort):
    """Render *messages* to prompt text, degrading gracefully on template errors.

    The chat template is tried with progressively fewer keyword arguments
    (tools + reasoning_effort, then tools only, then none) whenever it raises
    ``TypeError``. Any other error, or running out of attempts, falls back to
    ``generate_merged_prompt`` so every failure path is handled uniformly.
    """
    if not supports_chat_template(tokenizer):
        return generate_merged_prompt(messages, capabilities)

    full_kwargs = {}
    if tools is not None:
        full_kwargs["tools"] = tools
    if reasoning_effort is not None:
        full_kwargs["reasoning_effort"] = reasoning_effort
    tools_only = {"tools": tools} if tools is not None else {}

    for kwargs in (full_kwargs, tools_only, {}):
        try:
            return tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                tokenize=False,
                **kwargs,
            )
        except TypeError:
            # Likely an unsupported keyword argument: retry with fewer.
            continue
        except Exception:
            # Not a signature problem; retrying with fewer kwargs won't help.
            break

    prompt = generate_merged_prompt(messages, capabilities)
    sys.stderr.write(
        "--- cache_prefill: fallback to generate_merged_prompt\n"
    )
    return prompt
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from backends.base import ModelBackend
|
|
8
|
+
from mlx_lm.models.cache import trim_prompt_cache
|
|
9
|
+
from utils.prompt_builder import generate_merged_prompt, supports_chat_template
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _read_cache_token_count(cache_path: str) -> int | None:
    """Read token count from the sidecar .meta.json file.

    Returns ``None`` when the sidecar is missing or unreadable, is not valid
    JSON, is not a JSON object, or has no integer-convertible ``token_count``.
    """
    meta_path = cache_path + '.meta.json'
    try:
        with open(meta_path, encoding='utf-8') as f:
            meta = json.load(f)
        count = meta.get('token_count')
        return int(count) if count is not None else None
    except (OSError, json.JSONDecodeError, ValueError, TypeError, AttributeError):
        # OSError covers missing/unreadable paths; AttributeError covers
        # valid JSON whose top level is not an object (no .get()).
        return None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _stream_to_stdout(
    backend: ModelBackend,
    prompt: str | list[int],
    options: dict,
    images: list | None = None,
    primer: str | None = None,
    prompt_cache: list | None = None,
) -> None:
    """Stream generated text for *prompt* to stdout as one NUL-terminated frame.

    The primer (if any) is echoed first, followed by each response's text.
    NUL and record-separator (``\\x1e``) characters are stripped from both,
    because the frame uses them as delimiters. If the last response exposes
    ``prompt_tokens`` / ``generation_tokens``, they are appended as
    ``\\x1e__META__:<json>`` before the terminating NUL.
    """
    # Characters reserved by the framing protocol; removed from payload text.
    reserved = {0x00: None, 0x1E: None}

    if primer is not None:
        print(primer.translate(reserved), end="", flush=True)

    last_response = None
    for response in backend.stream_generate(prompt, options, images, prompt_cache=prompt_cache):
        print(response.text.translate(reserved), end="", flush=True)
        last_response = response

    meta: dict = {}
    if last_response is not None:
        for field in ("prompt_tokens", "generation_tokens"):
            if hasattr(last_response, field):
                meta[field] = getattr(last_response, field)

    if meta:
        print(f"\x1e__META__:{json.dumps(meta)}", end="\0", flush=True)
    else:
        print("", end="\0", flush=True)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def handle_chat(
    backend: ModelBackend,
    capabilities: dict,
    messages: list,
    primer: str | None = None,
    options: dict | None = None,
    tools: list | None = None,
    images: list | None = None,
    max_image_size: int = 768,
    reasoning_effort: str | None = None,
    cache_path: str | None = None,
    cache_trim_tokens: int | None = None,
) -> None:
    """Handle a chat API request and stream the reply to stdout.

    The messages are rendered with the tokenizer's chat template (or with
    ``generate_merged_prompt`` when no template is supported). A *primer* is
    appended as a partial assistant turn and the rendered prompt is cut right
    after it so generation continues the primer. For non-vision backends a
    saved prompt cache at *cache_path* can be reused: only the prompt tokens
    beyond the cached count are then sent to the backend.
    """
    if options is None:
        options = {}
    if primer == "":
        # An empty primer has nothing to continue from; treat it as absent
        # instead of failing while cutting the prompt.
        primer = None

    tokenizer = backend.get_tokenizer()

    # With a primer the assistant turn is already open, so the template must
    # not add its own generation prompt.
    fmt_messages = list(messages)
    add_generation_prompt = primer is None
    if primer is not None:
        fmt_messages.append({"role": "assistant", "content": primer})

    if backend.supports_vision():
        prompt = _apply_chat_template_with_fallbacks(
            tokenizer, fmt_messages, add_generation_prompt, [{"tools": tools}, {}]
        )
        if primer is not None:
            prompt = _cut_after_primer(prompt, primer)

        display_prompt = re.sub(r'(<\|image_pad\|>)+', '<|image_pad|>...', prompt)
        sys.stderr.write(f"--- vlm prompt (images: {len(images) if images else 0}, max_size: {max_image_size})\n{display_prompt}\n")

        final_options = dict(options)
        final_options["max_image_size"] = max_image_size
        _stream_to_stdout(
            backend,
            prompt,
            final_options,
            images=images,
            primer=primer,
        )
        return

    prompt_cache, cache_tokens = _load_chat_prompt_cache(
        backend, cache_path, cache_trim_tokens
    )

    # trust_remote_code is a template-side option only; it must never reach
    # the generation call.
    final_options = dict(options)
    final_options.pop("trust_remote_code", None)

    if not supports_chat_template(tokenizer):
        prompt = generate_merged_prompt(messages, capabilities)
        if prompt_cache is not None:
            sys.stderr.write("KV cache ignored: model does not support chat template\n")
        _stream_to_stdout(backend, prompt, final_options, primer=primer)
        return

    extra_kwargs = {}
    if tools is not None:
        extra_kwargs["tools"] = tools
    if reasoning_effort is not None:
        extra_kwargs["reasoning_effort"] = reasoning_effort

    trust_remote_code = options.get("trust_remote_code")
    if trust_remote_code is not None:
        extra_kwargs["trust_remote_code"] = trust_remote_code

    fallback_kwargs = {"tools": tools} if tools is not None else {}
    prompt = _apply_chat_template_with_fallbacks(
        tokenizer, fmt_messages, add_generation_prompt,
        [extra_kwargs, fallback_kwargs, {}],
    )

    if primer is not None:
        prompt = _cut_after_primer(prompt, primer)

    if isinstance(prompt, list):
        sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
    else:
        sys.stderr.write(f"--- prompt\n{prompt}\n")

    effective_prompt = prompt
    if prompt_cache is not None and cache_tokens > 0 and isinstance(prompt, str):
        # Tokenize the same way the backend does (no duplicated BOS) so the
        # cached token count lines up with this token list.
        add_special = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        full_tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

        if cache_tokens < len(full_tokens):
            effective_prompt = full_tokens[cache_tokens:]
            sys.stderr.write(
                f"Prefilled {cache_tokens}/{len(full_tokens)} tokens, "
                f"generating from {len(effective_prompt)} remaining\n"
            )
        else:
            sys.stderr.write(
                f"Prefill offset {cache_tokens} >= prompt {len(full_tokens)}, "
                f"ignoring prefill state\n"
            )
            prompt_cache = None

    _stream_to_stdout(backend, effective_prompt, final_options, primer=primer, prompt_cache=prompt_cache)


def _apply_chat_template_with_fallbacks(tokenizer, messages, add_generation_prompt, kwargs_candidates):
    """Apply the chat template, retrying with the next kwargs set on ``TypeError``.

    The ``TypeError`` raised by the last candidate is propagated.
    """
    last_index = len(kwargs_candidates) - 1
    for index, extra in enumerate(kwargs_candidates):
        try:
            return tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
                **extra,
            )
        except TypeError:
            if index == last_index:
                raise


def _cut_after_primer(prompt, primer):
    """Return *prompt* truncated right after the last occurrence of *primer*.

    This drops whatever the template appended after the partial assistant
    turn. If the primer cannot be found (the template altered it) the prompt
    is returned unchanged rather than being discarded.
    """
    if not isinstance(prompt, str):
        return prompt
    cut = prompt.rfind(primer)
    if cut < 0:
        sys.stderr.write("WARNING: primer not found in rendered prompt; prompt left unchanged\n")
        return prompt
    return prompt[: cut + len(primer)]


def _load_chat_prompt_cache(backend, cache_path, cache_trim_tokens):
    """Load the prompt cache for a chat request.

    Returns ``(prompt_cache, cache_tokens)``; ``prompt_cache`` is ``None``
    when no usable cache is available, in which case ``cache_tokens`` is 0.
    """
    prompt_cache = backend.load_cache_from_file(cache_path) if cache_path else None
    cache_tokens = 0
    if prompt_cache is not None:
        if cache_trim_tokens is not None:
            current_offset = backend.get_cache_offset(prompt_cache)
            if current_offset > cache_trim_tokens:
                requested = current_offset - cache_trim_tokens
                # trim_prompt_cache reports how many tokens were actually
                # removed (0 for untrimmable caches); a short trim would
                # leave the cache misaligned with the prompt slice.
                trimmed = trim_prompt_cache(prompt_cache, requested)
                if trimmed == requested:
                    sys.stderr.write(
                        f"KV cache trimmed: {current_offset} → {cache_trim_tokens} tokens\n"
                    )
                    cache_tokens = cache_trim_tokens
                else:
                    sys.stderr.write(
                        f"WARNING: KV cache could not be trimmed to {cache_trim_tokens} tokens "
                        f"(trimmed {trimmed} of {requested}). Ignoring cache.\n"
                    )
                    prompt_cache = None
            else:
                cache_tokens = current_offset
        else:
            meta_count = _read_cache_token_count(cache_path) if cache_path else None
            if meta_count is not None:
                cache_tokens = meta_count
            else:
                # Legacy cache without meta file - skip it for safety
                sys.stderr.write(
                    f"WARNING: Cache file exists but no .meta.json found at {cache_path}. "
                    "Ignoring cache for safety (may be from old implementation).\n"
                )
                prompt_cache = None
                cache_tokens = 0

    if prompt_cache is not None:
        sys.stderr.write(
            f"KV cache loaded: {len(prompt_cache)} layers, {cache_tokens} cached tokens\n"
        )
    elif cache_path:
        sys.stderr.write(f"KV cache load FAILED: {cache_path}\n")
    return prompt_cache, cache_tokens
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from backends.base import ModelBackend
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def handle_completion(
    backend: ModelBackend,
    prompt: str | list[int],
    options: dict | None = None,
    images: list | None = None,
    max_image_size: int = 768,
) -> None:
    """Handle a completion API request and stream the output to stdout.

    *prompt* is passed to ``backend.stream_generate`` unchanged; when images
    are supplied, ``max_image_size`` is added to the options. NUL characters
    are stripped from the generated text and the output is terminated with a
    newline followed by a NUL byte. With ``MLX_DEBUG`` set, the prompt (or
    its length for a token-id list) is logged to stderr first.
    """
    final_options = dict(options) if options is not None else {}
    if images:
        final_options["max_image_size"] = max_image_size

    if os.getenv('MLX_DEBUG'):
        if images:
            if isinstance(prompt, list):
                # Token-id prompts have no text to collapse; log the size.
                display_prompt = f"<{len(prompt)} token ids>"
            else:
                # Collapse long runs of image padding tokens for readability.
                display_prompt = re.sub(r'(<\|image_pad\|>)+', '<|image_pad|>...', prompt)
            sys.stderr.write(f"--- vlm completion (images: {len(images)}, max_size: {max_image_size})\n{display_prompt}\n")
        elif isinstance(prompt, list):
            sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
        else:
            sys.stderr.write(f"--- prompt\n{prompt}\n")

    for response in backend.stream_generate(prompt, final_options, images):
        print(response.text.replace("\0", ""), end="", flush=True)

    print("\n", end="\0", flush=True)
|