@modular-prompt/driver 0.11.15 → 0.13.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +25 -0
- package/dist/anthropic/anthropic-driver.d.ts +38 -8
- package/dist/anthropic/anthropic-driver.d.ts.map +1 -1
- package/dist/anthropic/anthropic-driver.js +180 -164
- package/dist/anthropic/anthropic-driver.js.map +1 -1
- package/dist/cache-controller.d.ts +28 -0
- package/dist/cache-controller.d.ts.map +1 -0
- package/dist/cache-controller.js +2 -0
- package/dist/cache-controller.js.map +1 -0
- package/dist/cache-utils.d.ts +20 -0
- package/dist/cache-utils.d.ts.map +1 -0
- package/dist/cache-utils.js +71 -0
- package/dist/cache-utils.js.map +1 -0
- package/dist/content-utils.d.ts +9 -0
- package/dist/content-utils.d.ts.map +1 -1
- package/dist/content-utils.js +47 -0
- package/dist/content-utils.js.map +1 -1
- package/dist/driver-registry/config-based-factory.d.ts.map +1 -1
- package/dist/driver-registry/config-based-factory.js +7 -0
- package/dist/driver-registry/config-based-factory.js.map +1 -1
- package/dist/driver-registry/factory-helper.d.ts.map +1 -1
- package/dist/driver-registry/factory-helper.js +7 -4
- package/dist/driver-registry/factory-helper.js.map +1 -1
- package/dist/driver-registry/types.d.ts +6 -0
- package/dist/driver-registry/types.d.ts.map +1 -1
- package/dist/formatter/converter.js +1 -1
- package/dist/formatter/converter.js.map +1 -1
- package/dist/google-genai/element-converter.d.ts +11 -0
- package/dist/google-genai/element-converter.d.ts.map +1 -0
- package/dist/google-genai/element-converter.js +126 -0
- package/dist/google-genai/element-converter.js.map +1 -0
- package/dist/google-genai/google-genai-cache-controller.d.ts +24 -0
- package/dist/google-genai/google-genai-cache-controller.d.ts.map +1 -0
- package/dist/google-genai/google-genai-cache-controller.js +127 -0
- package/dist/google-genai/google-genai-cache-controller.js.map +1 -0
- package/dist/google-genai/google-genai-driver.d.ts +5 -29
- package/dist/google-genai/google-genai-driver.d.ts.map +1 -1
- package/dist/google-genai/google-genai-driver.js +92 -255
- package/dist/google-genai/google-genai-driver.js.map +1 -1
- package/dist/index.d.ts +4 -0
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +3 -0
- package/dist/index.js.map +1 -1
- package/dist/mlx-ml/mlx-cache-controller.d.ts +66 -0
- package/dist/mlx-ml/mlx-cache-controller.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-cache-controller.js +600 -0
- package/dist/mlx-ml/mlx-cache-controller.js.map +1 -0
- package/dist/mlx-ml/mlx-driver.d.ts +13 -8
- package/dist/mlx-ml/mlx-driver.d.ts.map +1 -1
- package/dist/mlx-ml/mlx-driver.js +202 -143
- package/dist/mlx-ml/mlx-driver.js.map +1 -1
- package/dist/mlx-ml/mlx-message-utils.d.ts +9 -0
- package/dist/mlx-ml/mlx-message-utils.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-message-utils.js +71 -0
- package/dist/mlx-ml/mlx-message-utils.js.map +1 -0
- package/dist/mlx-ml/process/harmony-parser.d.ts +3 -0
- package/dist/mlx-ml/process/harmony-parser.d.ts.map +1 -0
- package/dist/mlx-ml/process/harmony-parser.js +175 -0
- package/dist/mlx-ml/process/harmony-parser.js.map +1 -0
- package/dist/mlx-ml/process/index.d.ts +7 -3
- package/dist/mlx-ml/process/index.d.ts.map +1 -1
- package/dist/mlx-ml/process/index.js +22 -7
- package/dist/mlx-ml/process/index.js.map +1 -1
- package/dist/mlx-ml/process/model-handlers.d.ts +11 -58
- package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -1
- package/dist/mlx-ml/process/model-handlers.js +29 -11
- package/dist/mlx-ml/process/model-handlers.js.map +1 -1
- package/dist/mlx-ml/process/model-specific.d.ts +7 -0
- package/dist/mlx-ml/process/model-specific.d.ts.map +1 -1
- package/dist/mlx-ml/process/model-specific.js +3 -0
- package/dist/mlx-ml/process/model-specific.js.map +1 -1
- package/dist/mlx-ml/process/parameter-validator.d.ts.map +1 -1
- package/dist/mlx-ml/process/parameter-validator.js +10 -3
- package/dist/mlx-ml/process/parameter-validator.js.map +1 -1
- package/dist/mlx-ml/process/process-communication.d.ts +3 -0
- package/dist/mlx-ml/process/process-communication.d.ts.map +1 -1
- package/dist/mlx-ml/process/process-communication.js +13 -0
- package/dist/mlx-ml/process/process-communication.js.map +1 -1
- package/dist/mlx-ml/process/queue.d.ts +5 -2
- package/dist/mlx-ml/process/queue.d.ts.map +1 -1
- package/dist/mlx-ml/process/queue.js +103 -15
- package/dist/mlx-ml/process/queue.js.map +1 -1
- package/dist/mlx-ml/process/response-processor.d.ts +18 -0
- package/dist/mlx-ml/process/response-processor.d.ts.map +1 -0
- package/dist/mlx-ml/process/response-processor.js +24 -0
- package/dist/mlx-ml/process/response-processor.js.map +1 -0
- package/dist/mlx-ml/process/types.d.ts +51 -4
- package/dist/mlx-ml/process/types.d.ts.map +1 -1
- package/dist/mlx-ml/tool-call-parser.d.ts.map +1 -1
- package/dist/mlx-ml/tool-call-parser.js +44 -68
- package/dist/mlx-ml/tool-call-parser.js.map +1 -1
- package/dist/mlx-ml/types.d.ts +1 -0
- package/dist/mlx-ml/types.d.ts.map +1 -1
- package/dist/openai/openai-driver.d.ts +0 -2
- package/dist/openai/openai-driver.d.ts.map +1 -1
- package/dist/openai/openai-driver.js.map +1 -1
- package/dist/types.d.ts +9 -0
- package/dist/types.d.ts.map +1 -1
- package/package.json +7 -4
- package/src/mlx-ml/python/__main__.py +41 -425
- package/src/mlx-ml/python/backends/__init__.py +3 -0
- package/src/mlx-ml/python/backends/base.py +84 -0
- package/src/mlx-ml/python/backends/mlx_lm.py +202 -0
- package/src/mlx-ml/python/backends/mlx_vlm.py +99 -0
- package/src/mlx-ml/python/examples/example_basic.py +93 -0
- package/src/mlx-ml/python/examples/example_tool_call.py +165 -0
- package/src/mlx-ml/python/handlers/__init__.py +6 -0
- package/src/mlx-ml/python/handlers/cache.py +81 -0
- package/src/mlx-ml/python/handlers/capabilities.py +6 -0
- package/src/mlx-ml/python/handlers/chat.py +221 -0
- package/src/mlx-ml/python/handlers/completion.py +36 -0
- package/src/mlx-ml/python/handlers/format_test.py +70 -0
- package/src/mlx-ml/python/handlers/tokenize.py +63 -0
- package/src/mlx-ml/python/pyproject.toml +15 -5
- package/src/mlx-ml/python/server.py +126 -0
- package/src/mlx-ml/python/tests/__init__.py +0 -0
- package/src/mlx-ml/python/utils/__init__.py +0 -0
- package/src/mlx-ml/python/utils/prompt_builder.py +54 -0
- package/src/mlx-ml/python/{token_utils.py → utils/token_utils.py} +13 -5
- package/src/mlx-ml/python/uv.lock +299 -57
- /package/src/mlx-ml/python/{chat_template_constraints.py → utils/chat_template_constraints.py} +0 -0
- /package/src/mlx-ml/python/{vlm_utils.py → utils/vlm_utils.py} +0 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Any, Iterator
|
|
7
|
+
|
|
8
|
+
from mlx_lm import load as mlx_lm_load
|
|
9
|
+
from mlx_lm import stream_generate as mlx_lm_stream_generate
|
|
10
|
+
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache, trim_prompt_cache
|
|
11
|
+
from mlx_lm.sample_utils import make_sampler
|
|
12
|
+
|
|
13
|
+
from backends.base import ModelBackend
|
|
14
|
+
from utils.token_utils import is_eod_token
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MlxLmBackend(ModelBackend):
    """`mlx_lm` backend for text-only models.

    Wraps ``mlx_lm.load`` / ``mlx_lm.stream_generate`` and manages on-disk KV
    prompt caches. Every cache saved at ``<path>`` gets a ``<path>.meta.json``
    side file recording how many prompt tokens it covers (plus optional prefix
    offsets/hashes supplied by the caller).
    """

    def __init__(self) -> None:
        # Both are populated by load(); None means "no model loaded yet".
        self.model: Any | None = None
        self.tokenizer: Any | None = None

    def load(self, model_name: str) -> None:
        """Load ``model_name`` (local path or repo id) with ``mlx_lm.load``."""
        self.model, self.tokenizer = mlx_lm_load(model_name)

    def get_tokenizer(self) -> Any:
        """Return the tokenizer set by :meth:`load`, or ``None`` before loading."""
        return self.tokenizer

    def stream_generate(
        self,
        prompt: str | list[int],
        options: dict,
        images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Yield ``mlx_lm.stream_generate`` responses for ``prompt``.

        ``temperature`` (default 1.0), ``top_p`` (default 0.0) and ``top_k``
        (default 0) are taken out of a copy of ``options`` and turned into a
        sampler; ``max_tokens`` defaults to 1000 and every remaining option is
        forwarded unchanged. ``images`` is accepted for interface compatibility
        and ignored (text-only backend). When ``prompt_cache`` is given it is
        passed through to ``mlx_lm``. Iteration stops, without yielding, at the
        first response that ``is_eod_token`` flags.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        final_options = {"max_tokens": 1000, **options}
        temperature = final_options.pop("temperature", 1.0)
        top_p = final_options.pop("top_p", 0.0)
        top_k = final_options.pop("top_k", 0)
        final_options["sampler"] = make_sampler(
            temp=temperature,
            top_p=top_p,
            top_k=top_k,
        )

        if prompt_cache is not None:
            final_options["prompt_cache"] = prompt_cache

        for response in mlx_lm_stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            **final_options,
        ):
            if is_eod_token(response, self.tokenizer):
                break
            yield response

    # get_cache_offset is inherited from ModelBackend base class

    def _tokenize_prompt(self, prompt: str) -> list[int]:
        """Tokenize a prompt string using the same logic as stream_generate.

        Special tokens are only added when the tokenizer has no BOS token or
        the prompt does not already start with it.
        """
        add_special = self.tokenizer.bos_token is None or not prompt.startswith(
            self.tokenizer.bos_token
        )
        return self.tokenizer.encode(prompt, add_special_tokens=add_special)

    def _trim_cache_to(self, prompt_cache: list, target: int) -> None:
        """Trim ``prompt_cache`` in place so it holds at most ``target`` tokens.

        Does nothing when the cache offset is already ``<= target``.
        """
        current_offset = self.get_cache_offset(prompt_cache)
        if current_offset > target:
            trim_prompt_cache(prompt_cache, current_offset - target)

    @staticmethod
    def _write_cache_meta(
        cache_path: str,
        token_count: int,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> None:
        """Write ``<cache_path>.meta.json`` describing a saved cache.

        Prefix offsets/hashes are only recorded when both are non-empty.
        Failures are reported on stderr and otherwise ignored (best effort).
        """
        meta_path = cache_path + '.meta.json'
        try:
            meta: dict[str, Any] = {"token_count": token_count}
            if prefix_offsets and prefix_hashes:
                meta["prefix_offsets"] = prefix_offsets
                meta["prefix_hashes"] = prefix_hashes
            with open(meta_path, 'w') as f:
                json.dump(meta, f)
        except Exception as e:
            sys.stderr.write(f"Failed to write cache meta: {e}\n")

    @staticmethod
    def _read_cache_meta(cache_path: str) -> int | None:
        """Return the ``token_count`` stored next to ``cache_path``, or ``None``.

        ``None`` is returned when the meta file is missing or unreadable, is
        not valid JSON, is not a JSON object, or has no usable ``token_count``.
        """
        meta_path = cache_path + '.meta.json'
        try:
            with open(meta_path) as f:
                meta = json.load(f)
            count = meta.get('token_count')
            return int(count) if count is not None else None
        except (OSError, ValueError, TypeError, AttributeError):
            # OSError covers FileNotFoundError/PermissionError, ValueError
            # covers json.JSONDecodeError, AttributeError a non-object payload.
            return None

    def _load_base_cache(
        self, base_cache_path: str, trim_to_tokens: int | None
    ) -> tuple[list, int]:
        """Load a base cache and determine how many prompt tokens it covers.

        Returns ``(prompt_cache, cache_offset)``. With ``trim_to_tokens`` the
        loaded cache is trimmed down to that length. Otherwise the length comes
        from the base cache's ``.meta.json``; without a usable meta file a fresh
        empty cache (offset 0) is returned instead. Errors from
        ``load_prompt_cache`` propagate to the caller.
        """
        prompt_cache = load_prompt_cache(base_cache_path)

        if trim_to_tokens is not None:
            current_offset = self.get_cache_offset(prompt_cache)
            if current_offset <= trim_to_tokens:
                return prompt_cache, current_offset
            trim_prompt_cache(prompt_cache, current_offset - trim_to_tokens)
            if os.getenv('MLX_DEBUG'):
                sys.stderr.write(
                    f"Trimmed base cache: {current_offset} → {trim_to_tokens} tokens\n"
                )
            return prompt_cache, trim_to_tokens

        meta_offset = self._read_cache_meta(base_cache_path)
        if meta_offset is None:
            # Legacy cache without meta file - create fresh cache
            sys.stderr.write(
                f"WARNING: Cache file exists but no .meta.json found at {base_cache_path}. "
                "Creating fresh cache (may be from old implementation).\n"
            )
            return make_prompt_cache(self.model), 0

        # The stored cache may hold entries past the recorded token count (see
        # the post-prefill trim in cache_prefill). Reusing them would leave a
        # stale token between the cached prefix and the newly prefilled suffix.
        self._trim_cache_to(prompt_cache, meta_offset)
        return prompt_cache, meta_offset

    def cache_prefill(
        self,
        cache_path: str,
        prompt: str,
        base_cache_path: str | None = None,
        trim_to_tokens: int | None = None,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> dict:
        """Run ``prompt`` through the model and save the resulting KV cache.

        When ``base_cache_path`` is given, that cache is reused and only the
        prompt tokens after its offset are processed (incremental prefill); see
        ``_load_base_cache`` for how the offset is chosen. If the base cache
        already covers the whole prompt it is saved directly (trimmed to the
        prompt length). Any error while loading or reusing the base cache falls
        back to a fresh prefill of the full prompt.

        The cache is saved to ``cache_path`` together with its ``.meta.json``.

        Returns:
            ``{"cache_path": cache_path, "token_count": <prompt token count>}``.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        full_tokens = self._tokenize_prompt(prompt)
        token_count = len(full_tokens)
        effective_prompt: str | list[int] = prompt

        if base_cache_path is None:
            prompt_cache = make_prompt_cache(self.model)
        else:
            try:
                prompt_cache, cache_offset = self._load_base_cache(
                    base_cache_path, trim_to_tokens
                )
                if cache_offset <= 0:
                    if os.getenv('MLX_DEBUG'):
                        sys.stderr.write(f"Incremental prefill from: {base_cache_path}\n")
                elif cache_offset < token_count:
                    # Only the suffix not covered by the base cache is processed.
                    effective_prompt = full_tokens[cache_offset:]
                    if os.getenv('MLX_DEBUG'):
                        sys.stderr.write(
                            f"Incremental prefill from: {base_cache_path} "
                            f"(skip {cache_offset}/{token_count} tokens)\n"
                        )
                else:
                    if os.getenv('MLX_DEBUG'):
                        sys.stderr.write(
                            f"Base cache covers entire prompt "
                            f"({cache_offset} >= {token_count}), saving as-is\n"
                        )
                    # Keep only this prompt's tokens so the cache matches its meta.
                    self._trim_cache_to(prompt_cache, token_count)
                    save_prompt_cache(cache_path, prompt_cache)
                    self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
                    return {"cache_path": cache_path, "token_count": token_count}
            except Exception as e:
                sys.stderr.write(f"Base cache load failed, creating fresh: {e}\n")
                prompt_cache = make_prompt_cache(self.model)
                # A fresh cache must be paired with the full prompt, never a suffix.
                effective_prompt = prompt

        if os.getenv('MLX_DEBUG'):
            sys.stderr.write(f"Prefill prompt: {token_count} tokens\n")

        # Generating a single token forces the prompt through the model,
        # filling prompt_cache as a side effect.
        for _ in mlx_lm_stream_generate(
            self.model, self.tokenizer, effective_prompt,
            prompt_cache=prompt_cache, max_tokens=1,
        ):
            break

        # mlx_lm's generate loop can run the first sampled token through the
        # model before yielding it, leaving one cache entry past the prompt.
        # Drop it so the saved cache and its meta token_count agree.
        self._trim_cache_to(prompt_cache, token_count)

        save_prompt_cache(cache_path, prompt_cache)
        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
        if os.getenv('MLX_DEBUG'):
            sys.stderr.write(f"Cache created: {cache_path} ({token_count} tokens)\n")
        return {"cache_path": cache_path, "token_count": token_count}

    def load_cache_from_file(self, cache_path: str) -> list | None:
        """Load a saved prompt cache, or return ``None`` (logged) on failure."""
        try:
            return load_prompt_cache(cache_path)
        except FileNotFoundError:
            sys.stderr.write(f"Cache file not found: {cache_path}\n")
            return None
        except Exception as e:
            sys.stderr.write(f"Failed to load cache: {e}\n")
            return None

    def supports_vision(self) -> bool:
        """Text-only backend: never accepts images."""
        return False

    @property
    def model_kind(self) -> str:
        """Backend kind identifier."""
        return "lm"
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Any, Iterator
|
|
5
|
+
|
|
6
|
+
from mlx_vlm import load as mlx_vlm_load
|
|
7
|
+
from mlx_vlm import stream_generate as mlx_vlm_stream_generate
|
|
8
|
+
|
|
9
|
+
from backends.base import ModelBackend
|
|
10
|
+
from utils.vlm_utils import load_and_resize_images
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MlxVlmBackend(ModelBackend):
    """`mlx_vlm` backend for vision-language models.

    Optionally uses a drafter model (see :meth:`load_drafter`) for
    speculative decoding.
    """

    def __init__(self) -> None:
        # Populated by load(); None until a model has been loaded.
        self.model: Any | None = None
        self.processor: Any | None = None
        # Speculative-decoding drafter and its kind, set by load_drafter().
        self.drafter: Any | None = None
        self.drafter_kind: str | None = None
        # Forwarded as ``draft_block_size`` when not None.
        self.draft_block_size: int | None = None

    def load(self, model_name: str) -> None:
        """Load model and processor with ``mlx_vlm.load``."""
        self.model, self.processor = mlx_vlm_load(model_name)

    def load_drafter(self, drafter_model: str) -> None:
        """Load a speculative-decoding drafter via ``mlx_vlm``."""
        from mlx_vlm.speculative.drafters import load_drafter
        self.drafter, self.drafter_kind = load_drafter(drafter_model)
        sys.stderr.write(f"Drafter loaded: {drafter_model} (kind={self.drafter_kind})\n")

    def has_drafter(self) -> bool:
        """Return True once a drafter has been loaded."""
        return self.drafter is not None

    def get_tokenizer(self) -> Any:
        """Return the processor (it doubles as the tokenizer for mlx_vlm)."""
        return self.processor

    def stream_generate(
        self, prompt: str | list[int], options: dict, images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Yield ``mlx_vlm.stream_generate`` results for ``prompt``.

        Reads ``temperature`` (1.0), ``max_tokens`` (1000), ``top_p`` (0.0),
        ``top_k`` (0) and, when ``images`` is non-empty, ``max_image_size``
        (768) from a copy of ``options``; other keys are not forwarded, and
        ``prompt_cache`` is not used. Drafter arguments are added when a
        drafter is loaded. Speculative-decoding stats are reported (and reset)
        when the stream ends, including when the consumer stops early.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded")

        final_options = dict(options)
        temperature = final_options.pop("temperature", 1.0)
        max_tokens = final_options.pop("max_tokens", 1000)
        top_p = final_options.pop("top_p", 0.0)
        top_k = final_options.pop("top_k", 0)

        processed_images = None
        if images:
            max_image_size = final_options.pop("max_image_size", 768)
            processed_images = load_and_resize_images(images, max_image_size)

        draft_kwargs: dict[str, Any] = {}
        # `is not None` rather than truthiness: MLX modules are dict subclasses,
        # so their truthiness reflects their contents, not their presence.
        if self.drafter is not None:
            draft_kwargs["draft_model"] = self.drafter
            draft_kwargs["draft_kind"] = self.drafter_kind
            if self.draft_block_size is not None:
                draft_kwargs["draft_block_size"] = self.draft_block_size

        token_count = 0
        try:
            for result in mlx_vlm_stream_generate(
                self.model,
                self.processor,
                prompt,
                image=processed_images,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                **draft_kwargs,
            ):
                token_count += 1
                yield result
        finally:
            # Runs on exhaustion, error, or generator close (consumer `break`),
            # so per-generation drafter stats never leak into the next call.
            self._report_draft_stats(token_count)

    def _report_draft_stats(self, token_count: int) -> None:
        """Print speculative-decoding stats (unless MLX_NO_STATS) and reset them."""
        if self.drafter is None or not hasattr(self.drafter, 'accept_lens'):
            return
        accept_lens = self.drafter.accept_lens
        if not accept_lens:
            return
        avg_accepted = sum(accept_lens) / len(accept_lens)
        import os
        if not os.getenv('MLX_NO_STATS'):
            sys.stderr.write("\n[Speculative Decoding Stats]\n")
            sys.stderr.write(f" Rounds: {len(accept_lens)}\n")
            sys.stderr.write(f" Average accepted tokens/round: {avg_accepted:.2f}\n")
            sys.stderr.write(f" Total tokens generated: {token_count}\n")
            sys.stderr.write(f" Speedup factor: {avg_accepted:.2f}x (theoretical)\n")
        # Clear for next generation (independent of whether stats were shown).
        self.drafter.accept_lens = []

    def supports_vision(self) -> bool:
        """Vision-language backend: accepts images."""
        return True

    @property
    def model_kind(self) -> str:
        """Backend kind identifier."""
        return "vlm"
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# This file contains code to use LLM-jp-4 models with mlx-lm on Apple Silicon.
|
|
2
|
+
|
|
3
|
+
from mlx_lm import load, stream_generate
|
|
4
|
+
from mlx_lm.sample_utils import make_sampler
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main():
    """Ask LLM-jp-4 for a Japanese self-introduction and dump each stage.

    Prints the rendered prompt, its token IDs, the generated token IDs, the
    decoded response, the tokenizer's parsed response, and the Harmony
    messages recovered from the generated tokens.
    """
    model_id = "llm-jp/llm-jp-4-8b-thinking"  # alternative: "llm-jp/llm-jp-4-8b-instruct"
    model, tokenizer = load(model_id, tokenizer_config={"trust_remote_code": True})

    chat = [{"role": "user", "content": "日本語で自己紹介してください。"}]
    prompt: str = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        reasoning_effort="medium",
    )
    print("--- Prompt ---")
    print(prompt)

    input_ids = tokenizer.encode(prompt)
    print("--- Input IDs ---")
    print(input_ids)

    sampler = make_sampler(temp=0.7, top_p=0.9)
    stream = stream_generate(
        model, tokenizer, prompt=input_ids, max_tokens=1024, sampler=sampler,
    )
    generated_ids: list[int] = [step.token for step in stream]
    print("--- Generated IDs ---")
    print(generated_ids)

    response = tokenizer.decode(generated_ids)
    print("\n--- Response ---")
    print(response)

    parsed = tokenizer.parse_response(response)
    print("\n--- Parsed Response ---")
    for label, key in (("Role:", "role"), ("Thinking:", "thinking"), ("Content:", "content")):
        print(label, parsed.get(key))

    # The tokenizer ships a Harmony parser as `parse_harmony_message`; it takes
    # token IDs (not text) and returns Harmony message objects whose parts keep
    # their token spans. The generation output lacks the assistant header, so
    # it is prepended here for the parser to see a complete message.
    prefill_ids = tokenizer.encode("<|start|>assistant")
    harmony_messages = tokenizer.parse_harmony_message(prefill_ids + generated_ids)

    print("\n--- Parsed Harmony Messages ---")
    for index, message in enumerate(harmony_messages, start=1):
        print(f"Message {index}:")
        # The end type can be "END", "CALL", or "INCOMPLETE".
        print(" End Type:", message.end)

        sections = (
            (message.role, " Role Tokens:", " Role Text:", " Role Start Position:"),
            (message.channel, " Channel Tokens:", " Channel Text:", " Channel Start Position:"),
            (message.constrain, " Constrain Tokens:", " Constrain Text:", " Constrain Start Position:"),
            (message.content, " Content Tokens:", " Content Text:", " Content Start Position:"),
        )
        for part, tokens_label, text_label, start_label in sections:
            if not part:
                continue
            print(tokens_label, part.token_ids)
            print(text_label, repr(tokenizer.decode(part.token_ids)))
            print(start_label, part.start)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Tool call example using LLM-jp-4 with mlx-lm on Apple Silicon.
|
|
2
|
+
|
|
3
|
+
from mlx_lm import load, stream_generate
|
|
4
|
+
from mlx_lm.sample_utils import make_sampler
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def generate_response(model, tokenizer, input_ids, sampler):
    """Return the token IDs produced by ``stream_generate`` for ``input_ids``.

    Generation is capped at 1024 tokens and uses the supplied ``sampler``.
    """
    stream = stream_generate(
        model,
        tokenizer,
        prompt=input_ids,
        max_tokens=1024,
        sampler=sampler,
    )
    return [step.token for step in stream]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def main():
    """Demonstrate tool calling with LLM-jp-4 on mlx-lm.

    Builds a prompt with two tool definitions, a developer instruction and a
    few-shot ``get_current_time`` exchange, then asks about the weather in
    Tokyo. Prints the raw completion and, when the tokenizer provides
    ``parse_harmony_message``, the Harmony-parsed messages. A second turn that
    feeds a tool result back is kept below as commented-out reference code.
    """
    model, tokenizer = load(
        # The Harmony parsing below relies on the LLM-jp-4 tokenizer
        # (custom code, hence trust_remote_code).
        "llm-jp/llm-jp-4-8b-thinking",
        # "llm-jp/llm-jp-4-8b-instruct",
        # "mlx-community/llm-jp-4-32b-a3b-thinking-4bit",
        # "mlx-community/Qwen3.6-27B-4bit",
        tokenizer_config={"trust_remote_code": True},
    )

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_time",
                "description": "現在の日時を取得する",
                "parameters": {
                    "type": "object",
                    "properties": {},
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "指定された都市の現在の天気を取得する",
                "parameters": {
                    "type": "object",
                    "required": ["city"],
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "都市名(例: 東京、大阪)",
                        },
                    },
                },
            },
        },
    ]

    messages = [
        {"role": "developer", "content": "必要に応じて応答をTool Callに切り替えてください。functionsで定義されている機能を呼び出すことができます。"},
        # \nTool call format: <|start|>assistant to=functions.get_current_time<|channel|>commentary json<|message|>{"locate": "Asia/Tokyo"}<|call|>

        # few-shot: example of a tool call followed by its tool response
        # NOTE(review): get_current_time declares no parameters, yet the
        # few-shot call passes "locate"; confirm this is intended.
        {"role": "user", "content": "今何時?"},
        {
            "role": "assistant",
            "tool_calls": [{
                "function": {
                    "name": "get_current_time",
                    "arguments": {"locate": "Asia/Tokyo"},
                },
            }],
        },
        {
            "role": "tool",
            "content": {"datetime": "2026-04-24T15:30:00+09:00"},
        },
        {
            "role": "assistant",
            "content": "現在の時刻は15時30分です。",
        },
        # The actual request
        {"role": "user", "content": '東京の天気を教えてください。'},
    ]

    sampler = make_sampler(temp=0.7, top_p=0.9)

    # --- Turn 1: generate the tool call ---
    prompt: str = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
        trust_remote_code=True,
        # Harmony reasoning levels are "low" / "medium" / "high".
        reasoning_effort="medium",
    )

    print("=== Turn 1: Tool Call ===")
    print("--- Prompt ---")
    print(prompt)

    input_ids = tokenizer.encode(prompt)
    generated_ids = generate_response(model, tokenizer, input_ids, sampler)
    response = tokenizer.decode(generated_ids)

    print("\n--- Raw Response ---")
    print(response)

    # Parse the tool call with the Harmony parser. Only Harmony-format
    # tokenizers (LLM-jp-4) provide it; skip instead of crashing otherwise.
    if not hasattr(tokenizer, "parse_harmony_message"):
        print("\n(tokenizer has no parse_harmony_message; skipping Harmony parsing)")
    else:
        # Generation output lacks the assistant header; prepend it for the parser.
        response_prefill = tokenizer.encode("<|start|>assistant")
        parsed_harmony = tokenizer.parse_harmony_message(response_prefill + generated_ids)

        print("\n--- Parsed Harmony Messages ---")
        for i, message in enumerate(parsed_harmony, start=1):
            print(f"Message {i}:")
            print(" End Type:", message.end)
            if message.role:
                print(" Role:", repr(tokenizer.decode(message.role.token_ids)))
            if message.channel:
                print(" Channel:", repr(tokenizer.decode(message.channel.token_ids)))
            if message.constrain:
                print(" Constrain:", repr(tokenizer.decode(message.constrain.token_ids)))
            if message.content:
                print(" Content:", repr(tokenizer.decode(message.content.token_ids)))

    # # --- Turn 2: pass the tool result back for the final answer ---
    # messages.append({
    #     "role": "assistant",
    #     "tool_calls": [{
    #         "function": {
    #             "name": "get_weather",
    #             "arguments": '{"city": "東京"}',
    #         },
    #     }],
    # })
    # messages.append({
    #     "role": "tool",
    #     "content": '{"city": "東京", "weather": "晴れ", "temperature": 22, "humidity": 45}',
    # })

    # prompt2: str = tokenizer.apply_chat_template(
    #     messages,
    #     tools=tools,
    #     tokenize=False,
    #     add_generation_prompt=True,
    # )

    # print("\n\n=== Turn 2: Final Response ===")
    # print("--- Prompt ---")
    # print(prompt2)

    # input_ids2 = tokenizer.encode(prompt2)
    # generated_ids2 = generate_response(model, tokenizer, input_ids2, sampler)
    # response2 = tokenizer.decode(generated_ids2)

    # print("\n--- Raw Response ---")
    # print(response2)

    # parsed = tokenizer.parse_response(response2)
    # print("\n--- Parsed Response ---")
    # print("Role:", parsed.get("role"))
    # print("Thinking:", parsed.get("thinking"))
    # print("Content:", parsed.get("content"))


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from handlers.cache import handle_cache_prefill
|
|
2
|
+
from handlers.capabilities import handle_capabilities
|
|
3
|
+
from handlers.chat import handle_chat
|
|
4
|
+
from handlers.completion import handle_completion
|
|
5
|
+
from handlers.format_test import handle_format_test
|
|
6
|
+
from handlers.tokenize import handle_tokenize
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from backends.base import ModelBackend
|
|
7
|
+
from utils.prompt_builder import generate_merged_prompt, supports_chat_template
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _render_chat_template(tokenizer, messages: list, tools: list | None, reasoning_effort: str | None) -> str | None:
    """Render ``messages`` with the tokenizer's chat template, or return ``None``.

    Tries progressively smaller sets of optional template arguments:
    ``tools`` + ``reasoning_effort``, then ``tools`` only, then none. Identical
    attempts are skipped. It moves to the next attempt only when a call raises
    ``TypeError`` (taken to mean an argument is not accepted). Any other
    exception, or ``TypeError`` on every attempt, yields ``None`` so the caller
    can fall back to a merged prompt.
    """
    full_kwargs: dict = {}
    if tools is not None:
        full_kwargs["tools"] = tools
    if reasoning_effort is not None:
        full_kwargs["reasoning_effort"] = reasoning_effort
    tools_kwargs: dict = {"tools": tools} if tools is not None else {}

    attempts: list[dict] = []
    for kwargs in (full_kwargs, tools_kwargs, {}):
        if kwargs not in attempts:
            attempts.append(kwargs)

    for kwargs in attempts:
        try:
            return tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                tokenize=False,
                **kwargs,
            )
        except TypeError:
            continue
        except Exception:
            # Template rendering failed outright; retrying with fewer
            # arguments is unlikely to help.
            return None
    return None


def handle_cache_prefill(
    backend: ModelBackend,
    capabilities: dict,
    cache_path: str,
    messages: list,
    base_cache_path: str | None = None,
    trim_to_tokens: int | None = None,
    prefix_offsets: list[int] | None = None,
    prefix_hashes: list[str] | None = None,
    tools: list | None = None,
    reasoning_effort: str | None = None,
) -> None:
    """Build the prompt for ``messages`` and prefill a KV cache with it.

    The prompt comes from the tokenizer's chat template (without a generation
    prompt) when ``supports_chat_template`` says one is available and it
    renders successfully; otherwise ``generate_merged_prompt`` is used. The
    remaining cache arguments are passed to ``backend.cache_prefill``. Its
    result dict (plus the prefix offsets/hashes, when both are non-empty) is
    printed to stdout as JSON terminated by a NUL byte.
    """
    import os

    tokenizer = backend.get_tokenizer()

    if supports_chat_template(tokenizer):
        prompt = _render_chat_template(tokenizer, messages, tools, reasoning_effort)
        if prompt is None:
            prompt = generate_merged_prompt(messages, capabilities)
            sys.stderr.write(
                "--- cache_prefill: fallback to generate_merged_prompt\n"
            )
    else:
        prompt = generate_merged_prompt(messages, capabilities)

    # Only show debug output if MLX_DEBUG environment variable is set
    if os.getenv('MLX_DEBUG'):
        sys.stderr.write(f"--- cache_prefill {cache_path}\n")
    result = backend.cache_prefill(
        cache_path, prompt, base_cache_path,
        trim_to_tokens=trim_to_tokens,
        prefix_offsets=prefix_offsets,
        prefix_hashes=prefix_hashes,
    )
    if prefix_offsets and prefix_hashes:
        result["prefix_offsets"] = prefix_offsets
        result["prefix_hashes"] = prefix_hashes
    print(json.dumps(result), end="\0", flush=True)
|