@modular-prompt/driver 0.12.0 → 0.13.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. package/dist/anthropic/anthropic-driver.d.ts +38 -8
  2. package/dist/anthropic/anthropic-driver.d.ts.map +1 -1
  3. package/dist/anthropic/anthropic-driver.js +180 -164
  4. package/dist/anthropic/anthropic-driver.js.map +1 -1
  5. package/dist/cache-controller.d.ts +31 -0
  6. package/dist/cache-controller.d.ts.map +1 -0
  7. package/dist/cache-controller.js +2 -0
  8. package/dist/cache-controller.js.map +1 -0
  9. package/dist/cache-utils.d.ts +20 -0
  10. package/dist/cache-utils.d.ts.map +1 -0
  11. package/dist/cache-utils.js +71 -0
  12. package/dist/cache-utils.js.map +1 -0
  13. package/dist/content-utils.d.ts.map +1 -1
  14. package/dist/content-utils.js +20 -0
  15. package/dist/content-utils.js.map +1 -1
  16. package/dist/driver-registry/config-based-factory.d.ts.map +1 -1
  17. package/dist/driver-registry/config-based-factory.js +7 -0
  18. package/dist/driver-registry/config-based-factory.js.map +1 -1
  19. package/dist/driver-registry/factory-helper.d.ts.map +1 -1
  20. package/dist/driver-registry/factory-helper.js +7 -4
  21. package/dist/driver-registry/factory-helper.js.map +1 -1
  22. package/dist/driver-registry/types.d.ts +6 -0
  23. package/dist/driver-registry/types.d.ts.map +1 -1
  24. package/dist/formatter/converter.js +1 -1
  25. package/dist/formatter/converter.js.map +1 -1
  26. package/dist/google-genai/element-converter.d.ts +11 -0
  27. package/dist/google-genai/element-converter.d.ts.map +1 -0
  28. package/dist/google-genai/element-converter.js +126 -0
  29. package/dist/google-genai/element-converter.js.map +1 -0
  30. package/dist/google-genai/google-genai-cache-controller.d.ts +24 -0
  31. package/dist/google-genai/google-genai-cache-controller.d.ts.map +1 -0
  32. package/dist/google-genai/google-genai-cache-controller.js +127 -0
  33. package/dist/google-genai/google-genai-cache-controller.js.map +1 -0
  34. package/dist/google-genai/google-genai-driver.d.ts +5 -29
  35. package/dist/google-genai/google-genai-driver.d.ts.map +1 -1
  36. package/dist/google-genai/google-genai-driver.js +92 -255
  37. package/dist/google-genai/google-genai-driver.js.map +1 -1
  38. package/dist/index.d.ts +4 -0
  39. package/dist/index.d.ts.map +1 -1
  40. package/dist/index.js +3 -0
  41. package/dist/index.js.map +1 -1
  42. package/dist/mlx-ml/mlx-cache-controller.d.ts +65 -0
  43. package/dist/mlx-ml/mlx-cache-controller.d.ts.map +1 -0
  44. package/dist/mlx-ml/mlx-cache-controller.js +624 -0
  45. package/dist/mlx-ml/mlx-cache-controller.js.map +1 -0
  46. package/dist/mlx-ml/mlx-driver.d.ts +12 -7
  47. package/dist/mlx-ml/mlx-driver.d.ts.map +1 -1
  48. package/dist/mlx-ml/mlx-driver.js +192 -124
  49. package/dist/mlx-ml/mlx-driver.js.map +1 -1
  50. package/dist/mlx-ml/mlx-message-utils.d.ts +9 -0
  51. package/dist/mlx-ml/mlx-message-utils.d.ts.map +1 -0
  52. package/dist/mlx-ml/mlx-message-utils.js +71 -0
  53. package/dist/mlx-ml/mlx-message-utils.js.map +1 -0
  54. package/dist/mlx-ml/process/index.d.ts +7 -3
  55. package/dist/mlx-ml/process/index.d.ts.map +1 -1
  56. package/dist/mlx-ml/process/index.js +22 -7
  57. package/dist/mlx-ml/process/index.js.map +1 -1
  58. package/dist/mlx-ml/process/model-handlers.d.ts +4 -59
  59. package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -1
  60. package/dist/mlx-ml/process/model-handlers.js +15 -14
  61. package/dist/mlx-ml/process/model-handlers.js.map +1 -1
  62. package/dist/mlx-ml/process/model-specific.d.ts +7 -0
  63. package/dist/mlx-ml/process/model-specific.d.ts.map +1 -1
  64. package/dist/mlx-ml/process/model-specific.js +3 -0
  65. package/dist/mlx-ml/process/model-specific.js.map +1 -1
  66. package/dist/mlx-ml/process/process-communication.d.ts +3 -0
  67. package/dist/mlx-ml/process/process-communication.d.ts.map +1 -1
  68. package/dist/mlx-ml/process/process-communication.js +13 -0
  69. package/dist/mlx-ml/process/process-communication.js.map +1 -1
  70. package/dist/mlx-ml/process/queue.d.ts +5 -2
  71. package/dist/mlx-ml/process/queue.d.ts.map +1 -1
  72. package/dist/mlx-ml/process/queue.js +101 -14
  73. package/dist/mlx-ml/process/queue.js.map +1 -1
  74. package/dist/mlx-ml/process/response-processor.d.ts +10 -0
  75. package/dist/mlx-ml/process/response-processor.d.ts.map +1 -1
  76. package/dist/mlx-ml/process/response-processor.js +23 -1
  77. package/dist/mlx-ml/process/response-processor.js.map +1 -1
  78. package/dist/mlx-ml/process/types.d.ts +50 -4
  79. package/dist/mlx-ml/process/types.d.ts.map +1 -1
  80. package/dist/mlx-ml/tool-call-parser/content-parsers.d.ts +9 -0
  81. package/dist/mlx-ml/tool-call-parser/content-parsers.d.ts.map +1 -0
  82. package/dist/mlx-ml/tool-call-parser/content-parsers.js +223 -0
  83. package/dist/mlx-ml/tool-call-parser/content-parsers.js.map +1 -0
  84. package/dist/mlx-ml/tool-call-parser/detector.d.ts +16 -0
  85. package/dist/mlx-ml/tool-call-parser/detector.d.ts.map +1 -0
  86. package/dist/mlx-ml/tool-call-parser/detector.js +58 -0
  87. package/dist/mlx-ml/tool-call-parser/detector.js.map +1 -0
  88. package/dist/mlx-ml/tool-call-parser/index.d.ts +7 -0
  89. package/dist/mlx-ml/tool-call-parser/index.d.ts.map +1 -0
  90. package/dist/mlx-ml/tool-call-parser/index.js +136 -0
  91. package/dist/mlx-ml/tool-call-parser/index.js.map +1 -0
  92. package/dist/mlx-ml/tool-call-parser/tool-formatter.d.ts +8 -0
  93. package/dist/mlx-ml/tool-call-parser/tool-formatter.d.ts.map +1 -0
  94. package/dist/mlx-ml/tool-call-parser/tool-formatter.js +88 -0
  95. package/dist/mlx-ml/tool-call-parser/tool-formatter.js.map +1 -0
  96. package/dist/mlx-ml/tool-call-parser/types.d.ts +18 -0
  97. package/dist/mlx-ml/tool-call-parser/types.d.ts.map +1 -0
  98. package/dist/mlx-ml/tool-call-parser/types.js +2 -0
  99. package/dist/mlx-ml/tool-call-parser/types.js.map +1 -0
  100. package/dist/mlx-ml/tool-call-parser/utils.d.ts +5 -0
  101. package/dist/mlx-ml/tool-call-parser/utils.d.ts.map +1 -0
  102. package/dist/mlx-ml/tool-call-parser/utils.js +77 -0
  103. package/dist/mlx-ml/tool-call-parser/utils.js.map +1 -0
  104. package/dist/types.d.ts +2 -0
  105. package/dist/types.d.ts.map +1 -1
  106. package/package.json +9 -4
  107. package/src/mlx-ml/python/__main__.py +41 -449
  108. package/src/mlx-ml/python/backends/__init__.py +3 -0
  109. package/src/mlx-ml/python/backends/base.py +84 -0
  110. package/src/mlx-ml/python/backends/mlx_lm.py +202 -0
  111. package/src/mlx-ml/python/backends/mlx_vlm.py +99 -0
  112. package/src/mlx-ml/python/handlers/__init__.py +6 -0
  113. package/src/mlx-ml/python/handlers/cache.py +81 -0
  114. package/src/mlx-ml/python/handlers/capabilities.py +6 -0
  115. package/src/mlx-ml/python/handlers/chat.py +221 -0
  116. package/src/mlx-ml/python/handlers/completion.py +36 -0
  117. package/src/mlx-ml/python/handlers/format_test.py +70 -0
  118. package/src/mlx-ml/python/handlers/tokenize.py +63 -0
  119. package/src/mlx-ml/python/pyproject.toml +13 -3
  120. package/src/mlx-ml/python/server.py +126 -0
  121. package/src/mlx-ml/python/tests/__init__.py +0 -0
  122. package/src/mlx-ml/python/utils/__init__.py +0 -0
  123. package/src/mlx-ml/python/utils/prompt_builder.py +54 -0
  124. package/src/mlx-ml/python/{token_utils.py → utils/token_utils.py} +9 -40
  125. package/src/mlx-ml/python/uv.lock +266 -41
  126. package/dist/mlx-ml/tool-call-parser.d.ts +0 -30
  127. package/dist/mlx-ml/tool-call-parser.d.ts.map +0 -1
  128. package/dist/mlx-ml/tool-call-parser.js +0 -623
  129. package/dist/mlx-ml/tool-call-parser.js.map +0 -1
  130. /package/src/mlx-ml/python/{example_basic.py → examples/example_basic.py} +0 -0
  131. /package/src/mlx-ml/python/{example_tool_call.py → examples/example_tool_call.py} +0 -0
  132. /package/src/mlx-ml/python/{chat_template_constraints.py → utils/chat_template_constraints.py} +0 -0
  133. /package/src/mlx-ml/python/{vlm_utils.py → utils/vlm_utils.py} +0 -0
@@ -0,0 +1,202 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ from typing import Any, Iterator
7
+
8
+ from mlx_lm import load as mlx_lm_load
9
+ from mlx_lm import stream_generate as mlx_lm_stream_generate
10
+ from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache, trim_prompt_cache
11
+ from mlx_lm.sample_utils import make_sampler
12
+
13
+ from backends.base import ModelBackend
14
+ from utils.token_utils import is_eod_token
15
+
16
+
17
class MlxLmBackend(ModelBackend):
    """`mlx_lm` backend for text-only models.

    Wraps model loading, streaming generation and on-disk prompt (KV) cache
    creation. Every cache file written by :meth:`cache_prefill` gets a sidecar
    ``<cache_path>.meta.json`` recording how many prompt tokens it covers.
    """

    def __init__(self) -> None:
        # Both stay None until load() has been called.
        self.model: Any | None = None
        self.tokenizer: Any | None = None

    def load(self, model_name: str) -> None:
        """Load the model and tokenizer for *model_name* via ``mlx_lm.load``."""
        self.model, self.tokenizer = mlx_lm_load(model_name)

    def get_tokenizer(self) -> Any:
        """Return the loaded tokenizer, or ``None`` if nothing is loaded yet."""
        return self.tokenizer

    def stream_generate(
        self,
        prompt: str | list[int],
        options: dict,
        images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Yield generation responses for *prompt* until an end token is seen.

        Args:
            prompt: Prompt text, or an already tokenized prompt.
            options: Generation options. ``temperature``, ``top_p`` and
                ``top_k`` are folded into a sampler; everything else is
                forwarded to ``mlx_lm.stream_generate`` as keyword arguments.
                ``max_tokens`` defaults to 1000 when absent.
            images: Accepted for interface parity; not used by this backend.
            prompt_cache: Optional prompt cache forwarded to the generator.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        final_options = {"max_tokens": 1000, **options}
        # Sampling parameters are not accepted directly by stream_generate;
        # they are packaged into a sampler instead.
        temperature = final_options.pop("temperature", 1.0)
        top_p = final_options.pop("top_p", 0.0)
        top_k = final_options.pop("top_k", 0)
        final_options["sampler"] = make_sampler(
            temp=temperature,
            top_p=top_p,
            top_k=top_k,
        )

        if prompt_cache is not None:
            final_options["prompt_cache"] = prompt_cache

        for response in mlx_lm_stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            **final_options,
        ):
            # Stop before emitting a response that is_eod_token flags.
            if is_eod_token(response, self.tokenizer):
                break
            yield response

    # get_cache_offset is inherited from ModelBackend base class

    def _tokenize_prompt(self, prompt: str) -> list[int]:
        """Tokenize *prompt*, adding special tokens unless it already starts with BOS.

        Intended to mirror how stream_generate tokenizes string prompts, so
        token counts line up with what ends up in the cache.
        """
        add_special = self.tokenizer.bos_token is None or not prompt.startswith(
            self.tokenizer.bos_token
        )
        return self.tokenizer.encode(prompt, add_special_tokens=add_special)

    @staticmethod
    def _write_cache_meta(
        cache_path: str,
        token_count: int,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> None:
        """Write the sidecar ``.meta.json`` for *cache_path* (best effort).

        The prefix fields are stored only when both are non-empty. Failures
        are reported on stderr and otherwise ignored.
        """
        meta_path = cache_path + '.meta.json'
        try:
            meta: dict[str, Any] = {"token_count": token_count}
            if prefix_offsets and prefix_hashes:
                meta["prefix_offsets"] = prefix_offsets
                meta["prefix_hashes"] = prefix_hashes
            with open(meta_path, 'w', encoding="utf-8") as f:
                json.dump(meta, f)
        except Exception as e:
            sys.stderr.write(f"Failed to write cache meta: {e}\n")

    @staticmethod
    def _read_cache_meta(cache_path: str) -> int | None:
        """Return ``token_count`` from the sidecar meta file, or ``None``.

        ``None`` is returned when the file is missing or unreadable, is not
        valid JSON, is not a JSON object, or has no usable ``token_count``.
        """
        meta_path = cache_path + '.meta.json'
        try:
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
        except (OSError, ValueError):
            # OSError covers missing file, permission problems and a directory
            # at that path; ValueError covers JSON and decoding errors.
            return None
        if not isinstance(meta, dict):
            return None
        count = meta.get('token_count')
        if count is None:
            return None
        try:
            return int(count)
        except (ValueError, TypeError):
            return None

    def cache_prefill(
        self,
        cache_path: str,
        prompt: str,
        base_cache_path: str | None = None,
        trim_to_tokens: int | None = None,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> dict:
        """Prefill a prompt cache for *prompt* and save it to *cache_path*.

        When *base_cache_path* is given, that cache is loaded and only the
        prompt tokens beyond what it already covers are processed. The
        covered amount is the cache offset after trimming to
        *trim_to_tokens* when that is given, otherwise the ``token_count``
        from the base cache's meta file. A base cache without a meta file is
        replaced by a fresh cache. Any failure on the base-cache path falls
        back to a full prefill with a fresh cache.

        Returns:
            ``{"cache_path": ..., "token_count": ...}`` where ``token_count``
            is the token length of the full prompt.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        debug = bool(os.getenv('MLX_DEBUG'))
        full_tokens = self._tokenize_prompt(prompt)
        token_count = len(full_tokens)
        effective_prompt: str | list[int] = prompt

        if base_cache_path is not None:
            try:
                prompt_cache = load_prompt_cache(base_cache_path)

                if trim_to_tokens is not None:
                    current_offset = self.get_cache_offset(prompt_cache)
                    if current_offset > trim_to_tokens:
                        trim_count = current_offset - trim_to_tokens
                        trim_prompt_cache(prompt_cache, trim_count)
                        if debug:
                            sys.stderr.write(
                                f"Trimmed base cache: {current_offset} → {trim_to_tokens} tokens\n"
                            )
                        cache_offset = trim_to_tokens
                    else:
                        cache_offset = current_offset
                else:
                    meta_offset = self._read_cache_meta(base_cache_path)
                    if meta_offset is None:
                        # Legacy cache without meta file - create fresh cache
                        sys.stderr.write(
                            f"WARNING: Cache file exists but no .meta.json found at {base_cache_path}. "
                            "Creating fresh cache (may be from old implementation).\n"
                        )
                        cache_offset = 0
                        prompt_cache = make_prompt_cache(self.model)
                    else:
                        cache_offset = meta_offset

                if cache_offset > 0 and prompt_cache is not None:
                    if cache_offset < token_count:
                        # Only feed the tokens the base cache does not cover.
                        effective_prompt = full_tokens[cache_offset:]
                        if debug:
                            sys.stderr.write(
                                f"Incremental prefill from: {base_cache_path} "
                                f"(skip {cache_offset}/{token_count} tokens)\n"
                            )
                    else:
                        if debug:
                            sys.stderr.write(
                                f"Base cache covers entire prompt "
                                f"({cache_offset} >= {token_count}), saving as-is\n"
                            )
                        save_prompt_cache(cache_path, prompt_cache)
                        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
                        return {"cache_path": cache_path, "token_count": token_count}
                else:
                    if debug:
                        sys.stderr.write(f"Incremental prefill from: {base_cache_path}\n")
            except Exception as e:
                sys.stderr.write(f"Base cache load failed, creating fresh: {e}\n")
                prompt_cache = make_prompt_cache(self.model)
                # A fresh cache must see the whole prompt, never a slice.
                effective_prompt = prompt
        else:
            prompt_cache = make_prompt_cache(self.model)

        if debug:
            sys.stderr.write(f"Prefill prompt: {token_count} tokens\n")

        # Run generation just far enough to push the prompt through the
        # model; the generated token itself is discarded.
        for _ in mlx_lm_stream_generate(
            self.model, self.tokenizer, effective_prompt,
            prompt_cache=prompt_cache, max_tokens=1,
        ):
            break

        save_prompt_cache(cache_path, prompt_cache)
        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
        if debug:
            sys.stderr.write(f"Cache created: {cache_path} ({token_count} tokens)\n")
        return {"cache_path": cache_path, "token_count": token_count}

    def load_cache_from_file(self, cache_path: str) -> list | None:
        """Load a saved prompt cache, returning ``None`` (and logging) on failure."""
        try:
            return load_prompt_cache(cache_path)
        except FileNotFoundError:
            sys.stderr.write(f"Cache file not found: {cache_path}\n")
            return None
        except Exception as e:
            sys.stderr.write(f"Failed to load cache: {e}\n")
            return None

    def supports_vision(self) -> bool:
        """Text-only backend: always ``False``."""
        return False

    @property
    def model_kind(self) -> str:
        """Backend kind identifier, ``"lm"``."""
        return "lm"
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from typing import Any, Iterator
5
+
6
+ from mlx_vlm import load as mlx_vlm_load
7
+ from mlx_vlm import stream_generate as mlx_vlm_stream_generate
8
+
9
+ from backends.base import ModelBackend
10
+ from utils.vlm_utils import load_and_resize_images
11
+
12
+
13
+ class MlxVlmBackend(ModelBackend):
14
+ """`mlx_vlm` backend for vision-language models."""
15
+
16
+ def __init__(self) -> None:
17
+ self.model: Any | None = None
18
+ self.processor: Any | None = None
19
+ self.drafter: Any | None = None
20
+ self.drafter_kind: str | None = None
21
+ self.draft_block_size: int | None = None
22
+
23
+ def load(self, model_name: str) -> None:
24
+ self.model, self.processor = mlx_vlm_load(model_name)
25
+
26
+ def load_drafter(self, drafter_model: str) -> None:
27
+ from mlx_vlm.speculative.drafters import load_drafter
28
+ self.drafter, self.drafter_kind = load_drafter(drafter_model)
29
+ sys.stderr.write(f"Drafter loaded: {drafter_model} (kind={self.drafter_kind})\n")
30
+
31
+ def has_drafter(self) -> bool:
32
+ return self.drafter is not None
33
+
34
+ def get_tokenizer(self) -> Any:
35
+ return self.processor
36
+
37
+ def stream_generate(
38
+ self, prompt: str | list[int], options: dict, images: list | None = None,
39
+ prompt_cache: list | None = None,
40
+ ) -> Iterator[Any]:
41
+ if self.model is None or self.processor is None:
42
+ raise RuntimeError("Model is not loaded")
43
+
44
+ final_options = dict(options)
45
+ temperature = final_options.pop("temperature", 1.0)
46
+ max_tokens = final_options.pop("max_tokens", 1000)
47
+ top_p = final_options.pop("top_p", 0.0)
48
+ top_k = final_options.pop("top_k", 0)
49
+
50
+ processed_images = None
51
+ if images:
52
+ max_image_size = final_options.pop("max_image_size", 768)
53
+ processed_images = load_and_resize_images(images, max_image_size)
54
+
55
+ draft_kwargs = {}
56
+ if self.drafter:
57
+ draft_kwargs["draft_model"] = self.drafter
58
+ draft_kwargs["draft_kind"] = self.drafter_kind
59
+ if self.draft_block_size is not None:
60
+ draft_kwargs["draft_block_size"] = self.draft_block_size
61
+
62
+ # Generate and collect tokens
63
+ token_count = 0
64
+ for result in mlx_vlm_stream_generate(
65
+ self.model,
66
+ self.processor,
67
+ prompt,
68
+ image=processed_images,
69
+ max_tokens=max_tokens,
70
+ temperature=temperature,
71
+ top_p=top_p,
72
+ top_k=top_k,
73
+ **draft_kwargs,
74
+ ):
75
+ token_count += 1
76
+ yield result
77
+
78
+ # Output speculative decoding stats if drafter was used
79
+ if self.drafter and hasattr(self.drafter, 'accept_lens'):
80
+ accept_lens = self.drafter.accept_lens
81
+ if accept_lens:
82
+ avg_accepted = sum(accept_lens) / len(accept_lens)
83
+ # Show stats unless MLX_NO_STATS environment variable is set
84
+ import os
85
+ if not os.getenv('MLX_NO_STATS'):
86
+ sys.stderr.write(f"\n[Speculative Decoding Stats]\n")
87
+ sys.stderr.write(f" Rounds: {len(accept_lens)}\n")
88
+ sys.stderr.write(f" Average accepted tokens/round: {avg_accepted:.2f}\n")
89
+ sys.stderr.write(f" Total tokens generated: {token_count}\n")
90
+ sys.stderr.write(f" Speedup factor: {avg_accepted:.2f}x (theoretical)\n")
91
+ # Clear for next generation
92
+ self.drafter.accept_lens = []
93
+
94
+ def supports_vision(self) -> bool:
95
+ return True
96
+
97
+ @property
98
+ def model_kind(self) -> str:
99
+ return "vlm"
@@ -0,0 +1,6 @@
1
+ from handlers.cache import handle_cache_prefill
2
+ from handlers.capabilities import handle_capabilities
3
+ from handlers.chat import handle_chat
4
+ from handlers.completion import handle_completion
5
+ from handlers.format_test import handle_format_test
6
+ from handlers.tokenize import handle_tokenize
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import sys
5
+
6
+ from backends.base import ModelBackend
7
+ from utils.prompt_builder import generate_merged_prompt, supports_chat_template
8
+
9
+
10
def handle_cache_prefill(
    backend: ModelBackend,
    capabilities: dict,
    cache_path: str,
    messages: list,
    base_cache_path: str | None = None,
    trim_to_tokens: int | None = None,
    prefix_offsets: list[int] | None = None,
    prefix_hashes: list[str] | None = None,
    tools: list | None = None,
    reasoning_effort: str | None = None,
) -> None:
    """Build the prompt for *messages* and ask the backend to prefill a cache.

    The prompt is rendered with the tokenizer's chat template (no generation
    prompt) when one is supported, retrying with fewer keyword arguments
    when the template call raises ``TypeError``. When no template is
    supported, or a retry fails for another reason, the merged-prompt
    builder is used instead.

    The backend's result, plus the prefix fields when both are non-empty,
    is written to stdout as JSON terminated by a NUL character.
    """
    import os

    tokenizer = backend.get_tokenizer()

    def render(**template_kwargs):
        # Single place that knows the fixed template arguments.
        return tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=False,
            **template_kwargs,
        )

    def merged_fallback():
        text = generate_merged_prompt(messages, capabilities)
        sys.stderr.write(
            "--- cache_prefill: fallback to generate_merged_prompt\n"
        )
        return text

    if not supports_chat_template(tokenizer):
        prompt = generate_merged_prompt(messages, capabilities)
    else:
        tools_only = {} if tools is None else {"tools": tools}
        all_kwargs = dict(tools_only)
        if reasoning_effort is not None:
            all_kwargs["reasoning_effort"] = reasoning_effort

        try:
            # First attempt: every optional argument we have.
            prompt = render(**all_kwargs)
        except TypeError:
            try:
                # Second attempt: drop reasoning_effort, keep tools.
                prompt = render(**tools_only)
            except TypeError:
                try:
                    # Last attempt: bare template call.
                    prompt = render()
                except Exception:
                    prompt = merged_fallback()
            except Exception:
                prompt = merged_fallback()

    # Only show debug output if MLX_DEBUG environment variable is set
    if os.getenv('MLX_DEBUG'):
        sys.stderr.write(f"--- cache_prefill {cache_path}\n")

    result = backend.cache_prefill(
        cache_path,
        prompt,
        base_cache_path,
        trim_to_tokens=trim_to_tokens,
        prefix_offsets=prefix_offsets,
        prefix_hashes=prefix_hashes,
    )
    if prefix_offsets and prefix_hashes:
        result["prefix_offsets"] = prefix_offsets
        result["prefix_hashes"] = prefix_hashes

    print(json.dumps(result), end="\0", flush=True)
@@ -0,0 +1,6 @@
1
+ import json
2
+
3
+
4
def handle_capabilities(capabilities: dict) -> None:
    """Handle the capabilities API: print *capabilities* as JSON, terminated by a NUL character."""
    serialized = json.dumps(capabilities)
    print(serialized, end="\0", flush=True)
@@ -0,0 +1,221 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import re
5
+ import sys
6
+
7
+ from backends.base import ModelBackend
8
+ from mlx_lm.models.cache import trim_prompt_cache
9
+ from utils.prompt_builder import generate_merged_prompt, supports_chat_template
10
+
11
+
12
+ def _read_cache_token_count(cache_path: str) -> int | None:
13
+ """Read token count from the sidecar .meta.json file."""
14
+ meta_path = cache_path + '.meta.json'
15
+ try:
16
+ with open(meta_path) as f:
17
+ meta = json.load(f)
18
+ count = meta.get('token_count')
19
+ return int(count) if count is not None else None
20
+ except (FileNotFoundError, json.JSONDecodeError, ValueError, TypeError):
21
+ return None
22
+
23
+
24
+ def _stream_to_stdout(
25
+ backend: ModelBackend,
26
+ prompt: str | list[int],
27
+ options: dict,
28
+ images: list | None = None,
29
+ primer: str | None = None,
30
+ prompt_cache: list | None = None,
31
+ ) -> None:
32
+ if primer is not None:
33
+ print(primer, end="", flush=True)
34
+
35
+ last_response = None
36
+ for response in backend.stream_generate(prompt, options, images, prompt_cache=prompt_cache):
37
+ print(response.text.replace("\0", "").replace("\x1e", ""), end="", flush=True)
38
+ last_response = response
39
+
40
+ meta: dict = {}
41
+ if last_response is not None:
42
+ if hasattr(last_response, "prompt_tokens"):
43
+ meta["prompt_tokens"] = last_response.prompt_tokens
44
+ if hasattr(last_response, "generation_tokens"):
45
+ meta["generation_tokens"] = last_response.generation_tokens
46
+
47
+ if meta:
48
+ print(f"\x1e__META__:{json.dumps(meta)}", end="\0", flush=True)
49
+ else:
50
+ print("", end="\0", flush=True)
51
+
52
+
53
def handle_chat(
    backend: ModelBackend,
    capabilities: dict,
    messages: list,
    primer: str | None = None,
    options: dict | None = None,
    tools: list | None = None,
    images: list | None = None,
    max_image_size: int = 768,
    reasoning_effort: str | None = None,
    cache_path: str | None = None,
    cache_trim_tokens: int | None = None,
) -> None:
    """Handle the chat API: render *messages* into a prompt and stream the reply.

    Args:
        backend: Backend used for tokenization, cache loading and generation.
        capabilities: Passed to the merged-prompt builder when the tokenizer
            has no chat template.
        messages: Chat messages to render.
        primer: Optional start of the assistant reply. It is appended as an
            assistant message, the rendered prompt is cut right after it,
            and it is echoed to stdout before the generated text. An empty
            primer is treated as no primer.
        options: Generation options; ``trust_remote_code`` is forwarded to
            the chat template and removed before generation.
        tools: Optional tool definitions for the chat template.
        images: Images for vision backends.
        max_image_size: Forwarded to vision backends via the options.
        reasoning_effort: Optional extra chat-template argument.
        cache_path: Optional saved prompt cache to reuse (text backends only).
        cache_trim_tokens: If given, trim the loaded cache to this many
            tokens; otherwise the covered token count comes from the cache's
            sidecar meta file.
    """
    if options is None:
        options = {}
    if primer == "":
        # An empty primer has nothing to continue from, and splitting on an
        # empty separator raises ValueError.
        primer = None

    tokenizer = backend.get_tokenizer()

    if backend.supports_vision():
        add_generation_prompt = True
        fmt_messages = list(messages)
        if primer is not None:
            fmt_messages.append({"role": "assistant", "content": primer})
            add_generation_prompt = False

        try:
            prompt = tokenizer.apply_chat_template(
                fmt_messages,
                tools=tools,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
            )
        except TypeError:
            # Template call does not accept `tools`; retry without it.
            prompt = tokenizer.apply_chat_template(
                fmt_messages,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
            )

        if primer is not None:
            prompt = _truncate_after_primer(prompt, primer)

        # Collapse runs of image pad tokens so the log stays readable.
        display_prompt = re.sub(r'(<\|image_pad\|>)+', '<|image_pad|>...', prompt)
        sys.stderr.write(f"--- vlm prompt (images: {len(images) if images else 0}, max_size: {max_image_size})\n{display_prompt}\n")

        final_options = dict(options)
        final_options["max_image_size"] = max_image_size
        _stream_to_stdout(
            backend,
            prompt,
            final_options,
            images=images,
            primer=primer,
        )
        return

    prompt_cache = backend.load_cache_from_file(cache_path) if cache_path else None
    cache_tokens = 0
    if prompt_cache is not None:
        if cache_trim_tokens is not None:
            current_offset = backend.get_cache_offset(prompt_cache)
            if current_offset > cache_trim_tokens:
                trim_prompt_cache(prompt_cache, current_offset - cache_trim_tokens)
                sys.stderr.write(
                    f"KV cache trimmed: {current_offset} → {cache_trim_tokens} tokens\n"
                )
                cache_tokens = cache_trim_tokens
            else:
                cache_tokens = current_offset
        else:
            meta_count = _read_cache_token_count(cache_path) if cache_path else None
            if meta_count is not None:
                cache_tokens = meta_count
            else:
                # Legacy cache without meta file - skip it for safety
                sys.stderr.write(
                    f"WARNING: Cache file exists but no .meta.json found at {cache_path}. "
                    "Ignoring cache for safety (may be from old implementation).\n"
                )
                prompt_cache = None
                cache_tokens = 0
        if prompt_cache is not None:
            sys.stderr.write(
                f"KV cache loaded: {len(prompt_cache)} layers, {cache_tokens} cached tokens\n"
            )
    elif cache_path:
        sys.stderr.write(f"KV cache load FAILED: {cache_path}\n")

    if not supports_chat_template(tokenizer):
        prompt = generate_merged_prompt(messages, capabilities)
        if prompt_cache is not None:
            sys.stderr.write("KV cache ignored: model does not support chat template\n")
        _stream_to_stdout(backend, prompt, options, primer=primer)
        return

    add_generation_prompt = True
    fmt_messages = list(messages)
    if primer is not None:
        fmt_messages.append({"role": "assistant", "content": primer})
        add_generation_prompt = False

    extra_kwargs = {}
    if tools is not None:
        extra_kwargs["tools"] = tools
    if reasoning_effort is not None:
        extra_kwargs["reasoning_effort"] = reasoning_effort

    trust_remote_code = options.get("trust_remote_code")
    if trust_remote_code is not None:
        extra_kwargs["trust_remote_code"] = trust_remote_code

    try:
        prompt = tokenizer.apply_chat_template(
            fmt_messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
            **extra_kwargs,
        )
    except TypeError:
        # Retry with progressively fewer optional arguments.
        try:
            fallback_kwargs = {}
            if tools is not None:
                fallback_kwargs["tools"] = tools
            prompt = tokenizer.apply_chat_template(
                fmt_messages,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
                **fallback_kwargs,
            )
        except TypeError:
            prompt = tokenizer.apply_chat_template(
                fmt_messages,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
            )

    if primer is not None:
        prompt = _truncate_after_primer(prompt, primer)

    if isinstance(prompt, list):
        sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
    else:
        sys.stderr.write(f"--- prompt\n{prompt}\n")

    final_options = dict(options)
    # trust_remote_code was only meant for the chat template call above.
    final_options.pop("trust_remote_code", None)

    effective_prompt = prompt
    if prompt_cache is not None and cache_tokens > 0 and isinstance(prompt, str):
        add_special = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        full_tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

        if cache_tokens < len(full_tokens):
            # Feed only the tokens the cache does not already cover.
            effective_prompt = full_tokens[cache_tokens:]
            sys.stderr.write(
                f"Prefilled {cache_tokens}/{len(full_tokens)} tokens, "
                f"generating from {len(effective_prompt)} remaining\n"
            )
        else:
            sys.stderr.write(
                f"Prefill offset {cache_tokens} >= prompt {len(full_tokens)}, "
                f"ignoring prefill state\n"
            )
            prompt_cache = None

    _stream_to_stdout(backend, effective_prompt, final_options, primer=primer, prompt_cache=prompt_cache)


def _truncate_after_primer(prompt, primer: str):
    """Cut *prompt* right after the last occurrence of *primer*.

    This drops whatever the chat template appended after the primed
    assistant message, so generation continues the primer. When the primer
    does not appear verbatim in the rendered prompt (or the prompt is not a
    string), the prompt is returned unchanged rather than being collapsed
    to the primer alone.
    """
    if not isinstance(prompt, str):
        return prompt
    pieces = prompt.split(primer)
    if len(pieces) < 2:
        return prompt
    return primer.join(pieces[:-1]) + primer
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import sys
6
+
7
+ from backends.base import ModelBackend
8
+
9
+
10
def handle_completion(
    backend: ModelBackend,
    prompt: str | list[int],
    options: dict | None = None,
    images: list | None = None,
    max_image_size: int = 768,
) -> None:
    """Handle the completion API: stream a raw completion of *prompt* to stdout.

    Args:
        backend: Backend whose ``stream_generate`` produces the responses.
        prompt: Prompt text, or an already tokenized prompt.
        options: Generation options; copied, never mutated.
        images: Optional images; when given, ``max_image_size`` is added to
            the options passed to the backend.
        max_image_size: Image size limit forwarded alongside images.

    Each response's text is printed with NUL characters removed. The output
    ends with a newline followed by a NUL character. When the ``MLX_DEBUG``
    environment variable is set, the prompt is logged to stderr first.
    """
    final_options = dict(options) if options is not None else {}
    debug = bool(os.getenv('MLX_DEBUG'))

    if images:
        final_options["max_image_size"] = max_image_size
        if debug:
            if isinstance(prompt, str):
                # Collapse runs of image pad tokens so the log stays readable.
                display_prompt = re.sub(r'(<\|image_pad\|>)+', '<|image_pad|>...', prompt)
            else:
                # A tokenized prompt cannot go through re.sub; log its size.
                display_prompt = f"<tokenized prompt: len={len(prompt)}>"
            sys.stderr.write(f"--- vlm completion (images: {len(images)}, max_size: {max_image_size})\n{display_prompt}\n")
    elif debug:
        if isinstance(prompt, list):
            sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
        else:
            sys.stderr.write(f"--- prompt\n{prompt}\n")

    for response in backend.stream_generate(prompt, final_options, images):
        # NUL is the output terminator, so it must not appear in the text.
        print(response.text.replace("\0", ""), end="", flush=True)

    print("\n", end="\0", flush=True)