@modular-prompt/driver 0.11.15 → 0.13.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. package/README.md +25 -0
  2. package/dist/anthropic/anthropic-driver.d.ts +38 -8
  3. package/dist/anthropic/anthropic-driver.d.ts.map +1 -1
  4. package/dist/anthropic/anthropic-driver.js +180 -164
  5. package/dist/anthropic/anthropic-driver.js.map +1 -1
  6. package/dist/cache-controller.d.ts +28 -0
  7. package/dist/cache-controller.d.ts.map +1 -0
  8. package/dist/cache-controller.js +2 -0
  9. package/dist/cache-controller.js.map +1 -0
  10. package/dist/cache-utils.d.ts +20 -0
  11. package/dist/cache-utils.d.ts.map +1 -0
  12. package/dist/cache-utils.js +71 -0
  13. package/dist/cache-utils.js.map +1 -0
  14. package/dist/content-utils.d.ts +9 -0
  15. package/dist/content-utils.d.ts.map +1 -1
  16. package/dist/content-utils.js +47 -0
  17. package/dist/content-utils.js.map +1 -1
  18. package/dist/driver-registry/config-based-factory.d.ts.map +1 -1
  19. package/dist/driver-registry/config-based-factory.js +7 -0
  20. package/dist/driver-registry/config-based-factory.js.map +1 -1
  21. package/dist/driver-registry/factory-helper.d.ts.map +1 -1
  22. package/dist/driver-registry/factory-helper.js +7 -4
  23. package/dist/driver-registry/factory-helper.js.map +1 -1
  24. package/dist/driver-registry/types.d.ts +6 -0
  25. package/dist/driver-registry/types.d.ts.map +1 -1
  26. package/dist/formatter/converter.js +1 -1
  27. package/dist/formatter/converter.js.map +1 -1
  28. package/dist/google-genai/element-converter.d.ts +11 -0
  29. package/dist/google-genai/element-converter.d.ts.map +1 -0
  30. package/dist/google-genai/element-converter.js +126 -0
  31. package/dist/google-genai/element-converter.js.map +1 -0
  32. package/dist/google-genai/google-genai-cache-controller.d.ts +24 -0
  33. package/dist/google-genai/google-genai-cache-controller.d.ts.map +1 -0
  34. package/dist/google-genai/google-genai-cache-controller.js +127 -0
  35. package/dist/google-genai/google-genai-cache-controller.js.map +1 -0
  36. package/dist/google-genai/google-genai-driver.d.ts +5 -29
  37. package/dist/google-genai/google-genai-driver.d.ts.map +1 -1
  38. package/dist/google-genai/google-genai-driver.js +92 -255
  39. package/dist/google-genai/google-genai-driver.js.map +1 -1
  40. package/dist/index.d.ts +4 -0
  41. package/dist/index.d.ts.map +1 -1
  42. package/dist/index.js +3 -0
  43. package/dist/index.js.map +1 -1
  44. package/dist/mlx-ml/mlx-cache-controller.d.ts +66 -0
  45. package/dist/mlx-ml/mlx-cache-controller.d.ts.map +1 -0
  46. package/dist/mlx-ml/mlx-cache-controller.js +600 -0
  47. package/dist/mlx-ml/mlx-cache-controller.js.map +1 -0
  48. package/dist/mlx-ml/mlx-driver.d.ts +13 -8
  49. package/dist/mlx-ml/mlx-driver.d.ts.map +1 -1
  50. package/dist/mlx-ml/mlx-driver.js +202 -143
  51. package/dist/mlx-ml/mlx-driver.js.map +1 -1
  52. package/dist/mlx-ml/mlx-message-utils.d.ts +9 -0
  53. package/dist/mlx-ml/mlx-message-utils.d.ts.map +1 -0
  54. package/dist/mlx-ml/mlx-message-utils.js +71 -0
  55. package/dist/mlx-ml/mlx-message-utils.js.map +1 -0
  56. package/dist/mlx-ml/process/harmony-parser.d.ts +3 -0
  57. package/dist/mlx-ml/process/harmony-parser.d.ts.map +1 -0
  58. package/dist/mlx-ml/process/harmony-parser.js +175 -0
  59. package/dist/mlx-ml/process/harmony-parser.js.map +1 -0
  60. package/dist/mlx-ml/process/index.d.ts +7 -3
  61. package/dist/mlx-ml/process/index.d.ts.map +1 -1
  62. package/dist/mlx-ml/process/index.js +22 -7
  63. package/dist/mlx-ml/process/index.js.map +1 -1
  64. package/dist/mlx-ml/process/model-handlers.d.ts +11 -58
  65. package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -1
  66. package/dist/mlx-ml/process/model-handlers.js +29 -11
  67. package/dist/mlx-ml/process/model-handlers.js.map +1 -1
  68. package/dist/mlx-ml/process/model-specific.d.ts +7 -0
  69. package/dist/mlx-ml/process/model-specific.d.ts.map +1 -1
  70. package/dist/mlx-ml/process/model-specific.js +3 -0
  71. package/dist/mlx-ml/process/model-specific.js.map +1 -1
  72. package/dist/mlx-ml/process/parameter-validator.d.ts.map +1 -1
  73. package/dist/mlx-ml/process/parameter-validator.js +10 -3
  74. package/dist/mlx-ml/process/parameter-validator.js.map +1 -1
  75. package/dist/mlx-ml/process/process-communication.d.ts +3 -0
  76. package/dist/mlx-ml/process/process-communication.d.ts.map +1 -1
  77. package/dist/mlx-ml/process/process-communication.js +13 -0
  78. package/dist/mlx-ml/process/process-communication.js.map +1 -1
  79. package/dist/mlx-ml/process/queue.d.ts +5 -2
  80. package/dist/mlx-ml/process/queue.d.ts.map +1 -1
  81. package/dist/mlx-ml/process/queue.js +103 -15
  82. package/dist/mlx-ml/process/queue.js.map +1 -1
  83. package/dist/mlx-ml/process/response-processor.d.ts +18 -0
  84. package/dist/mlx-ml/process/response-processor.d.ts.map +1 -0
  85. package/dist/mlx-ml/process/response-processor.js +24 -0
  86. package/dist/mlx-ml/process/response-processor.js.map +1 -0
  87. package/dist/mlx-ml/process/types.d.ts +51 -4
  88. package/dist/mlx-ml/process/types.d.ts.map +1 -1
  89. package/dist/mlx-ml/tool-call-parser.d.ts.map +1 -1
  90. package/dist/mlx-ml/tool-call-parser.js +44 -68
  91. package/dist/mlx-ml/tool-call-parser.js.map +1 -1
  92. package/dist/mlx-ml/types.d.ts +1 -0
  93. package/dist/mlx-ml/types.d.ts.map +1 -1
  94. package/dist/openai/openai-driver.d.ts +0 -2
  95. package/dist/openai/openai-driver.d.ts.map +1 -1
  96. package/dist/openai/openai-driver.js.map +1 -1
  97. package/dist/types.d.ts +9 -0
  98. package/dist/types.d.ts.map +1 -1
  99. package/package.json +7 -4
  100. package/src/mlx-ml/python/__main__.py +41 -425
  101. package/src/mlx-ml/python/backends/__init__.py +3 -0
  102. package/src/mlx-ml/python/backends/base.py +84 -0
  103. package/src/mlx-ml/python/backends/mlx_lm.py +202 -0
  104. package/src/mlx-ml/python/backends/mlx_vlm.py +99 -0
  105. package/src/mlx-ml/python/examples/example_basic.py +93 -0
  106. package/src/mlx-ml/python/examples/example_tool_call.py +165 -0
  107. package/src/mlx-ml/python/handlers/__init__.py +6 -0
  108. package/src/mlx-ml/python/handlers/cache.py +81 -0
  109. package/src/mlx-ml/python/handlers/capabilities.py +6 -0
  110. package/src/mlx-ml/python/handlers/chat.py +221 -0
  111. package/src/mlx-ml/python/handlers/completion.py +36 -0
  112. package/src/mlx-ml/python/handlers/format_test.py +70 -0
  113. package/src/mlx-ml/python/handlers/tokenize.py +63 -0
  114. package/src/mlx-ml/python/pyproject.toml +15 -5
  115. package/src/mlx-ml/python/server.py +126 -0
  116. package/src/mlx-ml/python/tests/__init__.py +0 -0
  117. package/src/mlx-ml/python/utils/__init__.py +0 -0
  118. package/src/mlx-ml/python/utils/prompt_builder.py +54 -0
  119. package/src/mlx-ml/python/{token_utils.py → utils/token_utils.py} +13 -5
  120. package/src/mlx-ml/python/uv.lock +299 -57
  121. package/src/mlx-ml/python/{chat_template_constraints.py → utils/chat_template_constraints.py} +0 -0
  122. package/src/mlx-ml/python/{vlm_utils.py → utils/vlm_utils.py} +0 -0
@@ -0,0 +1,202 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ from typing import Any, Iterator
7
+
8
+ from mlx_lm import load as mlx_lm_load
9
+ from mlx_lm import stream_generate as mlx_lm_stream_generate
10
+ from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache, trim_prompt_cache
11
+ from mlx_lm.sample_utils import make_sampler
12
+
13
+ from backends.base import ModelBackend
14
+ from utils.token_utils import is_eod_token
15
+
16
+
17
class MlxLmBackend(ModelBackend):
    """`mlx_lm` backend for text-only models.

    Generation goes through ``mlx_lm.stream_generate``. Prompt caches are
    written with ``save_prompt_cache`` together with a sidecar file
    ``<cache_path>.meta.json`` recording how many prompt tokens the cache
    was built from.
    """

    def __init__(self) -> None:
        # Both stay None until load() succeeds.
        self.model: Any | None = None
        self.tokenizer: Any | None = None

    def load(self, model_name: str) -> None:
        """Load model weights and tokenizer with ``mlx_lm.load``."""
        self.model, self.tokenizer = mlx_lm_load(model_name)

    def get_tokenizer(self) -> Any:
        """Return the tokenizer set by load() (None before loading)."""
        return self.tokenizer

    def stream_generate(
        self,
        prompt: str | list[int],
        options: dict,
        images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Stream generation responses for ``prompt``.

        Args:
            prompt: Prompt text or pre-tokenized ids.
            options: Generation options. ``temperature`` (default 1.0),
                ``top_p`` (default 0.0) and ``top_k`` (default 0) are turned
                into a sampler, ``max_tokens`` defaults to 1000, and every
                other key is forwarded to ``mlx_lm.stream_generate``. The
                caller's dict is copied, not modified.
            images: Ignored; this backend is text-only.
            prompt_cache: Optional prompt cache forwarded to mlx_lm.

        Yields:
            Responses from ``mlx_lm.stream_generate``. Iteration stops, without
            yielding it, at the first response ``is_eod_token`` flags.

        Raises:
            RuntimeError: On the first ``next()`` if load() was not called.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        final_options = {"max_tokens": 1000, **options}
        temperature = final_options.pop("temperature", 1.0)
        top_p = final_options.pop("top_p", 0.0)
        top_k = final_options.pop("top_k", 0)
        final_options["sampler"] = make_sampler(
            temp=temperature,
            top_p=top_p,
            top_k=top_k,
        )

        if prompt_cache is not None:
            final_options["prompt_cache"] = prompt_cache

        for response in mlx_lm_stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            **final_options,
        ):
            if is_eod_token(response, self.tokenizer):
                break
            yield response

    # get_cache_offset is inherited from ModelBackend base class

    def _tokenize_prompt(self, prompt: str) -> list[int]:
        """Tokenize a prompt string using the same logic as stream_generate.

        Special tokens are added unless the prompt already starts with the
        tokenizer's BOS token, which avoids a doubled BOS.
        """
        add_special = self.tokenizer.bos_token is None or not prompt.startswith(
            self.tokenizer.bos_token
        )
        return self.tokenizer.encode(prompt, add_special_tokens=add_special)

    @staticmethod
    def _write_cache_meta(
        cache_path: str,
        token_count: int,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> None:
        """Write ``<cache_path>.meta.json`` (best effort; errors go to stderr).

        Prefix offsets/hashes are stored only when both are non-empty.
        """
        meta_path = cache_path + '.meta.json'
        try:
            meta: dict[str, Any] = {"token_count": token_count}
            if prefix_offsets and prefix_hashes:
                meta["prefix_offsets"] = prefix_offsets
                meta["prefix_hashes"] = prefix_hashes
            with open(meta_path, 'w') as f:
                json.dump(meta, f)
        except Exception as e:
            sys.stderr.write(f"Failed to write cache meta: {e}\n")

    @staticmethod
    def _read_cache_meta(cache_path: str) -> int | None:
        """Return ``token_count`` from ``<cache_path>.meta.json``, or None.

        None is returned when the file is missing or unreadable, is not a
        JSON object, or has no integer-convertible ``token_count``.
        """
        meta_path = cache_path + '.meta.json'
        try:
            with open(meta_path) as f:
                meta = json.load(f)
        except (OSError, ValueError):
            # OSError covers FileNotFoundError/PermissionError;
            # ValueError covers JSONDecodeError and decoding errors.
            return None
        if not isinstance(meta, dict):
            return None
        count = meta.get('token_count')
        if count is None:
            return None
        try:
            return int(count)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _trim_cache_exactly(prompt_cache: list, num_tokens: int) -> None:
        """Remove ``num_tokens`` trailing tokens from ``prompt_cache`` or raise.

        Per mlx_lm's documentation, ``trim_prompt_cache`` returns how many
        tokens it actually removed, and 0 for caches that cannot be trimmed.
        Accepting a short trim would leave stale tokens in the cache.

        Raises:
            RuntimeError: If fewer than ``num_tokens`` tokens were removed.
        """
        trimmed = trim_prompt_cache(prompt_cache, num_tokens)
        if trimmed < num_tokens:
            raise RuntimeError(
                f"Could only trim {trimmed}/{num_tokens} tokens from base cache"
            )

    def cache_prefill(
        self,
        cache_path: str,
        prompt: str,
        base_cache_path: str | None = None,
        trim_to_tokens: int | None = None,
        prefix_offsets: list[int] | None = None,
        prefix_hashes: list[str] | None = None,
    ) -> dict:
        """Build a prompt cache for ``prompt`` and save it to ``cache_path``.

        With ``base_cache_path``, that cache is loaded and only the tokens past
        its offset are processed. The offset is read from the cache itself,
        after trimming it to ``trim_to_tokens``, when that argument is given.
        Otherwise it comes from the base cache's ``.meta.json``.

        A fresh cache is built instead when the base cache has no meta file,
        fails to load, or cannot be trimmed to the needed length. A base cache
        longer than the prompt is trimmed to the prompt length, so the saved
        cache matches the ``token_count`` recorded in its meta.

        Returns:
            ``{"cache_path": cache_path, "token_count": <prompt token count>}``.

        Raises:
            RuntimeError: If load() has not been called.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")

        full_tokens = self._tokenize_prompt(prompt)
        token_count = len(full_tokens)
        effective_prompt: str | list[int] = prompt

        if base_cache_path is not None:
            try:
                prompt_cache = load_prompt_cache(base_cache_path)

                if trim_to_tokens is not None:
                    current_offset = self.get_cache_offset(prompt_cache)
                    if current_offset > trim_to_tokens:
                        # Raises (-> fresh cache below) if the cache can't be trimmed.
                        self._trim_cache_exactly(
                            prompt_cache, current_offset - trim_to_tokens
                        )
                        if os.getenv('MLX_DEBUG'):
                            sys.stderr.write(
                                f"Trimmed base cache: {current_offset} → {trim_to_tokens} tokens\n"
                            )
                        cache_offset = trim_to_tokens
                    else:
                        cache_offset = current_offset
                else:
                    meta_offset = self._read_cache_meta(base_cache_path)
                    if meta_offset is None:
                        # Legacy cache without meta file - create fresh cache
                        sys.stderr.write(
                            f"WARNING: Cache file exists but no .meta.json found at {base_cache_path}. "
                            "Creating fresh cache (may be from old implementation).\n"
                        )
                        cache_offset = 0
                        prompt_cache = make_prompt_cache(self.model)
                    else:
                        cache_offset = meta_offset

                if cache_offset > 0 and prompt_cache is not None:
                    if cache_offset < token_count:
                        effective_prompt = full_tokens[cache_offset:]
                        if os.getenv('MLX_DEBUG'):
                            sys.stderr.write(
                                f"Incremental prefill from: {base_cache_path} "
                                f"(skip {cache_offset}/{token_count} tokens)\n"
                            )
                    else:
                        excess = cache_offset - token_count
                        if excess > 0:
                            # The base cache holds tokens past this prompt; drop them so
                            # the saved cache agrees with the token_count in its meta.
                            self._trim_cache_exactly(prompt_cache, excess)
                            if os.getenv('MLX_DEBUG'):
                                sys.stderr.write(
                                    f"Trimmed base cache: {cache_offset} → {token_count} tokens\n"
                                )
                        elif os.getenv('MLX_DEBUG'):
                            sys.stderr.write(
                                f"Base cache covers entire prompt "
                                f"({cache_offset} >= {token_count}), saving as-is\n"
                            )
                        save_prompt_cache(cache_path, prompt_cache)
                        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
                        return {"cache_path": cache_path, "token_count": token_count}
                else:
                    if os.getenv('MLX_DEBUG'):
                        sys.stderr.write(f"Incremental prefill from: {base_cache_path}\n")
            except Exception as e:
                sys.stderr.write(f"Base cache load failed, creating fresh: {e}\n")
                prompt_cache = make_prompt_cache(self.model)
                # A fresh cache must be prefilled with the whole prompt.
                effective_prompt = prompt
        else:
            prompt_cache = make_prompt_cache(self.model)

        if os.getenv('MLX_DEBUG'):
            sys.stderr.write(f"Prefill prompt: {token_count} tokens\n")

        # max_tokens=1: run the prompt through the model to fill prompt_cache,
        # then stop after the first streamed response.
        for _ in mlx_lm_stream_generate(
            self.model, self.tokenizer, effective_prompt,
            prompt_cache=prompt_cache, max_tokens=1,
        ):
            break

        save_prompt_cache(cache_path, prompt_cache)
        self._write_cache_meta(cache_path, token_count, prefix_offsets, prefix_hashes)
        if os.getenv('MLX_DEBUG'):
            sys.stderr.write(f"Cache created: {cache_path} ({token_count} tokens)\n")
        return {"cache_path": cache_path, "token_count": token_count}

    def load_cache_from_file(self, cache_path: str) -> list | None:
        """Load a saved prompt cache; log to stderr and return None on failure."""
        try:
            return load_prompt_cache(cache_path)
        except FileNotFoundError:
            sys.stderr.write(f"Cache file not found: {cache_path}\n")
            return None
        except Exception as e:
            sys.stderr.write(f"Failed to load cache: {e}\n")
            return None

    def supports_vision(self) -> bool:
        """Text-only backend: always False."""
        return False

    @property
    def model_kind(self) -> str:
        """Backend kind identifier: ``"lm"``."""
        return "lm"
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ from typing import Any, Iterator
5
+
6
+ from mlx_vlm import load as mlx_vlm_load
7
+ from mlx_vlm import stream_generate as mlx_vlm_stream_generate
8
+
9
+ from backends.base import ModelBackend
10
+ from utils.vlm_utils import load_and_resize_images
11
+
12
+
13
class MlxVlmBackend(ModelBackend):
    """`mlx_vlm` backend for vision-language models.

    A speculative-decoding drafter can be attached with load_drafter(). It is
    then passed to ``mlx_vlm.stream_generate``, and its acceptance stats are
    written to stderr after each generation.
    """

    def __init__(self) -> None:
        # Both stay None until load() succeeds.
        self.model: Any | None = None
        self.processor: Any | None = None
        # Speculative-decoding state, filled by load_drafter().
        self.drafter: Any | None = None
        self.drafter_kind: str | None = None
        # Not assigned in this class; forwarded to mlx_vlm when set by a caller.
        self.draft_block_size: int | None = None

    def load(self, model_name: str) -> None:
        """Load the model and its processor with ``mlx_vlm.load``."""
        self.model, self.processor = mlx_vlm_load(model_name)

    def load_drafter(self, drafter_model: str) -> None:
        """Load a speculative-decoding drafter and log its kind to stderr."""
        from mlx_vlm.speculative.drafters import load_drafter
        self.drafter, self.drafter_kind = load_drafter(drafter_model)
        sys.stderr.write(f"Drafter loaded: {drafter_model} (kind={self.drafter_kind})\n")

    def has_drafter(self) -> bool:
        """Return True once a drafter has been loaded."""
        return self.drafter is not None

    def get_tokenizer(self) -> Any:
        """Return the mlx_vlm processor (used where a tokenizer is expected)."""
        return self.processor

    def stream_generate(
        self, prompt: str | list[int], options: dict, images: list | None = None,
        prompt_cache: list | None = None,
    ) -> Iterator[Any]:
        """Stream generation results for ``prompt`` (and optional images).

        Only ``temperature`` (1.0), ``max_tokens`` (1000), ``top_p`` (0.0),
        ``top_k`` (0) and, when images are given, ``max_image_size`` (768) are
        read from ``options``; other keys are not forwarded. ``prompt_cache``
        is accepted for interface compatibility but is not used here.

        Raises:
            RuntimeError: On the first ``next()`` if load() was not called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded")

        final_options = dict(options)
        temperature = final_options.pop("temperature", 1.0)
        max_tokens = final_options.pop("max_tokens", 1000)
        top_p = final_options.pop("top_p", 0.0)
        top_k = final_options.pop("top_k", 0)

        processed_images = None
        if images:
            max_image_size = final_options.pop("max_image_size", 768)
            processed_images = load_and_resize_images(images, max_image_size)

        draft_kwargs: dict[str, Any] = {}
        # Explicit None check, matching has_drafter(): a model object's
        # truthiness is not a reliable "is loaded" signal.
        if self.drafter is not None:
            draft_kwargs["draft_model"] = self.drafter
            draft_kwargs["draft_kind"] = self.drafter_kind
            if self.draft_block_size is not None:
                draft_kwargs["draft_block_size"] = self.draft_block_size

        token_count = 0
        try:
            for result in mlx_vlm_stream_generate(
                self.model,
                self.processor,
                prompt,
                image=processed_images,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                **draft_kwargs,
            ):
                token_count += 1
                yield result
        finally:
            # Runs even if the consumer stops iterating early (generator close),
            # so drafter stats never carry over into the next generation.
            self._report_draft_stats(token_count)

    def _report_draft_stats(self, token_count: int) -> None:
        """Write the drafter's acceptance stats to stderr, then reset them.

        Does nothing unless a drafter with a non-empty ``accept_lens`` exists.
        Output is suppressed when MLX_NO_STATS is set; the reset still happens.
        ``token_count`` is the number of results streamed; whether that equals
        the number of tokens depends on mlx_vlm — TODO confirm.
        """
        import os

        if self.drafter is None or not hasattr(self.drafter, 'accept_lens'):
            return
        accept_lens = self.drafter.accept_lens
        if not accept_lens:
            return
        avg_accepted = sum(accept_lens) / len(accept_lens)
        if not os.getenv('MLX_NO_STATS'):
            sys.stderr.write("\n[Speculative Decoding Stats]\n")
            sys.stderr.write(f" Rounds: {len(accept_lens)}\n")
            sys.stderr.write(f" Average accepted tokens/round: {avg_accepted:.2f}\n")
            sys.stderr.write(f" Total tokens generated: {token_count}\n")
            sys.stderr.write(f" Speedup factor: {avg_accepted:.2f}x (theoretical)\n")
        # Clear for next generation
        self.drafter.accept_lens = []

    def supports_vision(self) -> bool:
        """Vision-language backend: always True."""
        return True

    @property
    def model_kind(self) -> str:
        """Backend kind identifier: ``"vlm"``."""
        return "vlm"
@@ -0,0 +1,93 @@
1
+ # This file contains code to use LLM-jp-4 models with mlx-lm on Apple Silicon.
2
+
3
+ from mlx_lm import load, stream_generate
4
+ from mlx_lm.sample_utils import make_sampler
5
+
6
+
7
def main():
    """Generate one reply from LLM-jp-4 and print it in raw and parsed forms.

    Prints the rendered prompt and its token ids, the generated ids, the
    decoded text, the ``parse_response`` result, and a per-segment breakdown
    of the ``parse_harmony_message`` output.
    """
    model, tokenizer = load(
        # "llm-jp/llm-jp-4-8b-instruct",
        "llm-jp/llm-jp-4-8b-thinking",
        tokenizer_config={"trust_remote_code": True},
    )

    conversation = [{"role": "user", "content": "日本語で自己紹介してください。"}]
    prompt_text: str = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        reasoning_effort="medium",
    )
    print("--- Prompt ---")
    print(prompt_text)

    prompt_ids = tokenizer.encode(prompt_text)
    print("--- Input IDs ---")
    print(prompt_ids)

    sampler = make_sampler(temp=0.7, top_p=0.9)
    output_ids: list[int] = [
        step.token
        for step in stream_generate(
            model, tokenizer, prompt=prompt_ids, max_tokens=1024, sampler=sampler,
        )
    ]
    print("--- Generated IDs ---")
    print(output_ids)

    decoded = tokenizer.decode(output_ids)
    print("\n--- Response ---")
    print(decoded)

    summary = tokenizer.parse_response(decoded)
    print("\n--- Parsed Response ---")
    print("Role:", summary.get("role"))
    print("Thinking:", summary.get("thinking"))
    print("Content:", summary.get("content"))

    # parse_harmony_message is a tokenizer method that takes token ids (not
    # text) and returns Harmony message objects with the tokens split per
    # segment. The assistant-header prefill must be prepended so the
    # generated ids parse as a complete message.
    header_ids = tokenizer.encode("<|start|>assistant")
    harmony_messages = tokenizer.parse_harmony_message(header_ids + output_ids)

    # (printed label, message attribute) for each optional Harmony segment.
    segment_fields = (
        ("Role", "role"),
        ("Channel", "channel"),
        ("Constrain", "constrain"),
        ("Content", "content"),
    )

    print("\n--- Parsed Harmony Messages ---")
    for number, message in enumerate(harmony_messages, start=1):
        print(f"Message {number}:")
        # The end type can be "END", "CALL", or "INCOMPLETE".
        print(" End Type:", message.end)
        for label, field in segment_fields:
            segment = getattr(message, field)
            if not segment:
                continue
            print(f" {label} Tokens:", segment.token_ids)
            print(f" {label} Text:", repr(tokenizer.decode(segment.token_ids)))
            print(f" {label} Start Position:", segment.start)


if __name__ == "__main__":
    main()
@@ -0,0 +1,165 @@
1
+ # Tool call example using LLM-jp-4 with mlx-lm on Apple Silicon.
2
+
3
+ from mlx_lm import load, stream_generate
4
+ from mlx_lm.sample_utils import make_sampler
5
+
6
+
7
def generate_response(model, tokenizer, input_ids, sampler):
    """Stream a completion for ``input_ids`` and return the generated token ids.

    Generation is capped at 1024 new tokens and uses the supplied sampler.
    """
    token_stream = stream_generate(
        model, tokenizer, prompt=input_ids, max_tokens=1024, sampler=sampler,
    )
    return [chunk.token for chunk in token_stream]
15
+
16
+
17
+ def main():
18
+ model, tokenizer = load(
19
+ # "llm-jp/llm-jp-4-8b-thinking",
20
+ # "llm-jp/llm-jp-4-8b-instruct",
21
+ # "mlx-community/llm-jp-4-32b-a3b-thinking-4bit",
22
+ "mlx-community/Qwen3.6-27B-4bit",
23
+ # tokenizer_config={"trust_remote_code": True},
24
+ )
25
+
26
+ tools = [
27
+ {
28
+ "type": "function",
29
+ "function": {
30
+ "name": "get_current_time",
31
+ "description": "現在の日時を取得する",
32
+ "parameters": {
33
+ "type": "object",
34
+ "properties": {},
35
+ },
36
+ },
37
+ },
38
+ {
39
+ "type": "function",
40
+ "function": {
41
+ "name": "get_weather",
42
+ "description": "指定された都市の現在の天気を取得する",
43
+ "parameters": {
44
+ "type": "object",
45
+ "required": ["city"],
46
+ "properties": {
47
+ "city": {
48
+ "type": "string",
49
+ "description": "都市名(例: 東京、大阪)",
50
+ },
51
+ },
52
+ },
53
+ },
54
+ },
55
+ ]
56
+
57
+ messages = [
58
+ {"role": "developer", "content": "必要に応じて応答をTool Callに切り替えてください。functionsで定義されている機能を呼び出すことができます。"},
59
+ # \nツール実行形式: <|start|>assistant to=functions.get_current_time<|channel|>commentary json<|message|>{"locate": "Asia/Tokyo"}<|call|>
60
+
61
+ # few-shot: tool call → tool response の例
62
+ {"role": "user", "content": "今何時?"},
63
+ {
64
+ "role": "assistant",
65
+ "tool_calls": [{
66
+ "function": {
67
+ "name": "get_current_time",
68
+ "arguments": {"locate": "Asia/Tokyo"},
69
+ },
70
+ }],
71
+ },
72
+ {
73
+ "role": "tool",
74
+ "content": {"datetime": "2026-04-24T15:30:00+09:00"},
75
+ },
76
+ {
77
+ "role": "assistant",
78
+ "content": "現在の時刻は15時30分です。",
79
+ },
80
+ # 本番のリクエスト
81
+ {"role": "user", "content": '東京の天気を教えてください。'},
82
+ ]
83
+
84
+ sampler = make_sampler(temp=0.7, top_p=0.9)
85
+
86
+ # --- Turn 1: tool call生成 ---
87
+ prompt: str = tokenizer.apply_chat_template(
88
+ messages,
89
+ tools=tools,
90
+ tokenize=False,
91
+ add_generation_prompt=True,
92
+ trust_remote_code=True,
93
+ reasoning_effort="middle",
94
+ )
95
+
96
+ print("=== Turn 1: Tool Call ===")
97
+ print("--- Prompt ---")
98
+ print(prompt)
99
+
100
+ input_ids = tokenizer.encode(prompt)
101
+ generated_ids = generate_response(model, tokenizer, input_ids, sampler)
102
+ response = tokenizer.decode(generated_ids)
103
+
104
+ print("\n--- Raw Response ---")
105
+ print(response)
106
+
107
+ # Harmony parserでtool callを解析
108
+ response_prefill = tokenizer.encode("<|start|>assistant")
109
+ parsed_harmony = tokenizer.parse_harmony_message(response_prefill + generated_ids)
110
+
111
+ print("\n--- Parsed Harmony Messages ---")
112
+ for i, message in enumerate(parsed_harmony, start=1):
113
+ print(f"Message {i}:")
114
+ print(" End Type:", message.end)
115
+ if message.role:
116
+ print(" Role:", repr(tokenizer.decode(message.role.token_ids)))
117
+ if message.channel:
118
+ print(" Channel:", repr(tokenizer.decode(message.channel.token_ids)))
119
+ if message.constrain:
120
+ print(" Constrain:", repr(tokenizer.decode(message.constrain.token_ids)))
121
+ if message.content:
122
+ print(" Content:", repr(tokenizer.decode(message.content.token_ids)))
123
+
124
+ # # --- Turn 2: tool resultを渡して最終応答 ---
125
+ # messages.append({
126
+ # "role": "assistant",
127
+ # "tool_calls": [{
128
+ # "function": {
129
+ # "name": "get_weather",
130
+ # "arguments": '{"city": "東京"}',
131
+ # },
132
+ # }],
133
+ # })
134
+ # messages.append({
135
+ # "role": "tool",
136
+ # "content": '{"city": "東京", "weather": "晴れ", "temperature": 22, "humidity": 45}',
137
+ # })
138
+
139
+ # prompt2: str = tokenizer.apply_chat_template(
140
+ # messages,
141
+ # tools=tools,
142
+ # tokenize=False,
143
+ # add_generation_prompt=True,
144
+ # )
145
+
146
+ # print("\n\n=== Turn 2: Final Response ===")
147
+ # print("--- Prompt ---")
148
+ # print(prompt2)
149
+
150
+ # input_ids2 = tokenizer.encode(prompt2)
151
+ # generated_ids2 = generate_response(model, tokenizer, input_ids2, sampler)
152
+ # response2 = tokenizer.decode(generated_ids2)
153
+
154
+ # print("\n--- Raw Response ---")
155
+ # print(response2)
156
+
157
+ # parsed = tokenizer.parse_response(response2)
158
+ # print("\n--- Parsed Response ---")
159
+ # print("Role:", parsed.get("role"))
160
+ # print("Thinking:", parsed.get("thinking"))
161
+ # print("Content:", parsed.get("content"))
162
+
163
+
164
+ if __name__ == "__main__":
165
+ main()
@@ -0,0 +1,6 @@
1
+ from handlers.cache import handle_cache_prefill
2
+ from handlers.capabilities import handle_capabilities
3
+ from handlers.chat import handle_chat
4
+ from handlers.completion import handle_completion
5
+ from handlers.format_test import handle_format_test
6
+ from handlers.tokenize import handle_tokenize
@@ -0,0 +1,81 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import sys
5
+
6
+ from backends.base import ModelBackend
7
+ from utils.prompt_builder import generate_merged_prompt, supports_chat_template
8
+
9
+
10
def handle_cache_prefill(
    backend: ModelBackend,
    capabilities: dict,
    cache_path: str,
    messages: list,
    base_cache_path: str | None = None,
    trim_to_tokens: int | None = None,
    prefix_offsets: list[int] | None = None,
    prefix_hashes: list[str] | None = None,
    tools: list | None = None,
    reasoning_effort: str | None = None,
) -> None:
    """Render ``messages`` into a prompt, prefill a cache file, and report it.

    The prompt is rendered with the tokenizer's chat template (no generation
    prompt) when the tokenizer supports one; otherwise, or if rendering
    fails, ``generate_merged_prompt`` is used. ``backend.cache_prefill`` builds
    and saves the cache.

    Output: one NUL-terminated JSON object on stdout, containing the dict
    returned by ``backend.cache_prefill``. ``prefix_offsets`` and
    ``prefix_hashes`` are added to it when both are non-empty.
    """
    import os

    prompt = _render_prefill_prompt(
        backend.get_tokenizer(), messages, capabilities, tools, reasoning_effort
    )

    # Only show debug output if MLX_DEBUG environment variable is set
    if os.getenv('MLX_DEBUG'):
        sys.stderr.write(f"--- cache_prefill {cache_path}\n")
    result = backend.cache_prefill(
        cache_path, prompt, base_cache_path,
        trim_to_tokens=trim_to_tokens,
        prefix_offsets=prefix_offsets,
        prefix_hashes=prefix_hashes,
    )
    if prefix_offsets and prefix_hashes:
        result["prefix_offsets"] = prefix_offsets
        result["prefix_hashes"] = prefix_hashes
    print(json.dumps(result), end="\0", flush=True)


def _render_prefill_prompt(
    tokenizer,
    messages: list,
    capabilities: dict,
    tools: list | None,
    reasoning_effort: str | None,
) -> str:
    """Render ``messages`` with the chat template, degrading step by step.

    The template is tried with tools + reasoning_effort, then with tools only,
    then with no extra kwargs. A ``TypeError`` (e.g. an unsupported keyword)
    moves on to the next attempt. Any other error, or running out of attempts,
    falls back to ``generate_merged_prompt``.
    """
    if not supports_chat_template(tokenizer):
        return generate_merged_prompt(messages, capabilities)

    full_kwargs: dict = {}
    if tools is not None:
        full_kwargs["tools"] = tools
    if reasoning_effort is not None:
        full_kwargs["reasoning_effort"] = reasoning_effort
    tools_only = {"tools": tools} if tools is not None else {}

    # Skip duplicate kwargs sets: retrying an identical call is pointless.
    attempts: list[dict] = []
    for kwargs in (full_kwargs, tools_only, {}):
        if kwargs not in attempts:
            attempts.append(kwargs)

    for kwargs in attempts:
        try:
            return tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                tokenize=False,
                **kwargs,
            )
        except TypeError:
            continue
        except Exception:
            # Every non-TypeError failure is handled uniformly, whichever
            # attempt raised it.
            break

    prompt = generate_merged_prompt(messages, capabilities)
    sys.stderr.write("--- cache_prefill: fallback to generate_merged_prompt\n")
    return prompt
@@ -0,0 +1,6 @@
1
+ import json
2
+
3
+
4
def handle_capabilities(capabilities: dict) -> None:
    """Handle the capabilities API: print ``capabilities`` as JSON, NUL-terminated."""
    payload = json.dumps(capabilities)
    print(payload, end="\0", flush=True)