@modular-prompt/driver 0.11.15 → 0.13.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. package/README.md +25 -0
  2. package/dist/anthropic/anthropic-driver.d.ts +38 -8
  3. package/dist/anthropic/anthropic-driver.d.ts.map +1 -1
  4. package/dist/anthropic/anthropic-driver.js +180 -164
  5. package/dist/anthropic/anthropic-driver.js.map +1 -1
  6. package/dist/cache-controller.d.ts +28 -0
  7. package/dist/cache-controller.d.ts.map +1 -0
  8. package/dist/cache-controller.js +2 -0
  9. package/dist/cache-controller.js.map +1 -0
  10. package/dist/cache-utils.d.ts +20 -0
  11. package/dist/cache-utils.d.ts.map +1 -0
  12. package/dist/cache-utils.js +71 -0
  13. package/dist/cache-utils.js.map +1 -0
  14. package/dist/content-utils.d.ts +9 -0
  15. package/dist/content-utils.d.ts.map +1 -1
  16. package/dist/content-utils.js +47 -0
  17. package/dist/content-utils.js.map +1 -1
  18. package/dist/driver-registry/config-based-factory.d.ts.map +1 -1
  19. package/dist/driver-registry/config-based-factory.js +7 -0
  20. package/dist/driver-registry/config-based-factory.js.map +1 -1
  21. package/dist/driver-registry/factory-helper.d.ts.map +1 -1
  22. package/dist/driver-registry/factory-helper.js +7 -4
  23. package/dist/driver-registry/factory-helper.js.map +1 -1
  24. package/dist/driver-registry/types.d.ts +6 -0
  25. package/dist/driver-registry/types.d.ts.map +1 -1
  26. package/dist/formatter/converter.js +1 -1
  27. package/dist/formatter/converter.js.map +1 -1
  28. package/dist/google-genai/element-converter.d.ts +11 -0
  29. package/dist/google-genai/element-converter.d.ts.map +1 -0
  30. package/dist/google-genai/element-converter.js +126 -0
  31. package/dist/google-genai/element-converter.js.map +1 -0
  32. package/dist/google-genai/google-genai-cache-controller.d.ts +24 -0
  33. package/dist/google-genai/google-genai-cache-controller.d.ts.map +1 -0
  34. package/dist/google-genai/google-genai-cache-controller.js +127 -0
  35. package/dist/google-genai/google-genai-cache-controller.js.map +1 -0
  36. package/dist/google-genai/google-genai-driver.d.ts +5 -29
  37. package/dist/google-genai/google-genai-driver.d.ts.map +1 -1
  38. package/dist/google-genai/google-genai-driver.js +92 -255
  39. package/dist/google-genai/google-genai-driver.js.map +1 -1
  40. package/dist/index.d.ts +4 -0
  41. package/dist/index.d.ts.map +1 -1
  42. package/dist/index.js +3 -0
  43. package/dist/index.js.map +1 -1
  44. package/dist/mlx-ml/mlx-cache-controller.d.ts +66 -0
  45. package/dist/mlx-ml/mlx-cache-controller.d.ts.map +1 -0
  46. package/dist/mlx-ml/mlx-cache-controller.js +600 -0
  47. package/dist/mlx-ml/mlx-cache-controller.js.map +1 -0
  48. package/dist/mlx-ml/mlx-driver.d.ts +13 -8
  49. package/dist/mlx-ml/mlx-driver.d.ts.map +1 -1
  50. package/dist/mlx-ml/mlx-driver.js +202 -143
  51. package/dist/mlx-ml/mlx-driver.js.map +1 -1
  52. package/dist/mlx-ml/mlx-message-utils.d.ts +9 -0
  53. package/dist/mlx-ml/mlx-message-utils.d.ts.map +1 -0
  54. package/dist/mlx-ml/mlx-message-utils.js +71 -0
  55. package/dist/mlx-ml/mlx-message-utils.js.map +1 -0
  56. package/dist/mlx-ml/process/harmony-parser.d.ts +3 -0
  57. package/dist/mlx-ml/process/harmony-parser.d.ts.map +1 -0
  58. package/dist/mlx-ml/process/harmony-parser.js +175 -0
  59. package/dist/mlx-ml/process/harmony-parser.js.map +1 -0
  60. package/dist/mlx-ml/process/index.d.ts +7 -3
  61. package/dist/mlx-ml/process/index.d.ts.map +1 -1
  62. package/dist/mlx-ml/process/index.js +22 -7
  63. package/dist/mlx-ml/process/index.js.map +1 -1
  64. package/dist/mlx-ml/process/model-handlers.d.ts +11 -58
  65. package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -1
  66. package/dist/mlx-ml/process/model-handlers.js +29 -11
  67. package/dist/mlx-ml/process/model-handlers.js.map +1 -1
  68. package/dist/mlx-ml/process/model-specific.d.ts +7 -0
  69. package/dist/mlx-ml/process/model-specific.d.ts.map +1 -1
  70. package/dist/mlx-ml/process/model-specific.js +3 -0
  71. package/dist/mlx-ml/process/model-specific.js.map +1 -1
  72. package/dist/mlx-ml/process/parameter-validator.d.ts.map +1 -1
  73. package/dist/mlx-ml/process/parameter-validator.js +10 -3
  74. package/dist/mlx-ml/process/parameter-validator.js.map +1 -1
  75. package/dist/mlx-ml/process/process-communication.d.ts +3 -0
  76. package/dist/mlx-ml/process/process-communication.d.ts.map +1 -1
  77. package/dist/mlx-ml/process/process-communication.js +13 -0
  78. package/dist/mlx-ml/process/process-communication.js.map +1 -1
  79. package/dist/mlx-ml/process/queue.d.ts +5 -2
  80. package/dist/mlx-ml/process/queue.d.ts.map +1 -1
  81. package/dist/mlx-ml/process/queue.js +103 -15
  82. package/dist/mlx-ml/process/queue.js.map +1 -1
  83. package/dist/mlx-ml/process/response-processor.d.ts +18 -0
  84. package/dist/mlx-ml/process/response-processor.d.ts.map +1 -0
  85. package/dist/mlx-ml/process/response-processor.js +24 -0
  86. package/dist/mlx-ml/process/response-processor.js.map +1 -0
  87. package/dist/mlx-ml/process/types.d.ts +51 -4
  88. package/dist/mlx-ml/process/types.d.ts.map +1 -1
  89. package/dist/mlx-ml/tool-call-parser.d.ts.map +1 -1
  90. package/dist/mlx-ml/tool-call-parser.js +44 -68
  91. package/dist/mlx-ml/tool-call-parser.js.map +1 -1
  92. package/dist/mlx-ml/types.d.ts +1 -0
  93. package/dist/mlx-ml/types.d.ts.map +1 -1
  94. package/dist/openai/openai-driver.d.ts +0 -2
  95. package/dist/openai/openai-driver.d.ts.map +1 -1
  96. package/dist/openai/openai-driver.js.map +1 -1
  97. package/dist/types.d.ts +9 -0
  98. package/dist/types.d.ts.map +1 -1
  99. package/package.json +7 -4
  100. package/src/mlx-ml/python/__main__.py +41 -425
  101. package/src/mlx-ml/python/backends/__init__.py +3 -0
  102. package/src/mlx-ml/python/backends/base.py +84 -0
  103. package/src/mlx-ml/python/backends/mlx_lm.py +202 -0
  104. package/src/mlx-ml/python/backends/mlx_vlm.py +99 -0
  105. package/src/mlx-ml/python/examples/example_basic.py +93 -0
  106. package/src/mlx-ml/python/examples/example_tool_call.py +165 -0
  107. package/src/mlx-ml/python/handlers/__init__.py +6 -0
  108. package/src/mlx-ml/python/handlers/cache.py +81 -0
  109. package/src/mlx-ml/python/handlers/capabilities.py +6 -0
  110. package/src/mlx-ml/python/handlers/chat.py +221 -0
  111. package/src/mlx-ml/python/handlers/completion.py +36 -0
  112. package/src/mlx-ml/python/handlers/format_test.py +70 -0
  113. package/src/mlx-ml/python/handlers/tokenize.py +63 -0
  114. package/src/mlx-ml/python/pyproject.toml +15 -5
  115. package/src/mlx-ml/python/server.py +126 -0
  116. package/src/mlx-ml/python/tests/__init__.py +0 -0
  117. package/src/mlx-ml/python/utils/__init__.py +0 -0
  118. package/src/mlx-ml/python/utils/prompt_builder.py +54 -0
  119. package/src/mlx-ml/python/{token_utils.py → utils/token_utils.py} +13 -5
  120. package/src/mlx-ml/python/uv.lock +299 -57
  121. package/src/mlx-ml/python/{chat_template_constraints.py → utils/chat_template_constraints.py} +0 -0
  122. package/src/mlx-ml/python/{vlm_utils.py → utils/vlm_utils.py} +0 -0
@@ -1,442 +1,58 @@
1
1
  import sys
2
- import json
3
- from vlm_utils import detect_model_kind, load_and_resize_images
4
- from token_utils import get_capabilities, is_eod_token
2
+
3
+ from backends import MlxLmBackend, MlxVlmBackend
4
+ from utils.token_utils import get_capabilities
5
+ from utils.vlm_utils import detect_model_kind
6
+ from server import Server
5
7
 
6
8
  model_name = sys.argv[1] if len(sys.argv) > 1 else "mlx-community/gemma-3-270m-it-qat-4bit"
7
9
  text_only = "--text-only" in sys.argv
8
10
 
9
- # モデル種別の判定とロード
10
- model_kind = "lm" if text_only else detect_model_kind(model_name)
11
-
12
- if model_kind == "vlm":
13
- from mlx_vlm import load as vlm_load, stream_generate as vlm_stream_generate
14
- try:
15
- model, processor = vlm_load(model_name)
16
- tokenizer = processor # capabilities取得用(VLMのprocessorもtokenizer互換)
17
- except (ValueError, Exception) as e:
18
- # mlx_vlm.models にモジュールが存在しても、実際のモデルに vision コンポーネントが
19
- # ない場合(例: Qwen3.5 テキストモデルが qwen2_vl として認識される)にフォールバック
20
- sys.stderr.write(f"VLM load failed, falling back to LM: {e}\n")
21
- model_kind = "lm"
22
- from mlx_lm import load, stream_generate
23
- from mlx_lm.sample_utils import make_sampler
24
- model, tokenizer = load(model_name)
25
- else:
26
- from mlx_lm import load, stream_generate
27
- from mlx_lm.sample_utils import make_sampler
28
- model, tokenizer = load(model_name)
29
-
30
- # Capabilities情報の取得
31
- capabilities = get_capabilities(tokenizer)
32
- capabilities["model_kind"] = model_kind
11
+ drafter_model = None
12
+ if "--drafter" in sys.argv:
13
+ idx = sys.argv.index("--drafter")
14
+ if idx + 1 < len(sys.argv):
15
+ drafter_model = sys.argv[idx + 1]
33
16
 
34
- def read():
35
- lines = []
36
- data = None
37
- eof = False
38
- while not eof:
39
- line = sys.stdin.readline()
40
- # sys.stderr.write('line:' + line + '\n')
41
- if not line:
42
- eof = True
43
- else:
44
- lines.append(line)
17
+ draft_block_size = None
18
+ if "--draft-block-size" in sys.argv:
19
+ idx = sys.argv.index("--draft-block-size")
20
+ if idx + 1 < len(sys.argv):
45
21
  try:
46
- data = json.loads(''.join(lines))
47
- except json.JSONDecodeError as e:
48
- data = None
49
- continue
50
- break
51
- return data
52
-
53
-
54
- def supports_chat_template():
55
- """
56
- チャットテンプレートがサポートされているかを判定
57
-
58
- apply_chat_templateメソッドの存在と、tokenizer.chat_templateの両方を確認する。
59
- tokenizer.chat_templateが設定されていない場合、apply_chat_templateを呼んでも
60
- エラーになるため、両方の条件をチェックする必要がある。
61
-
62
- Returns:
63
- bool: チャットテンプレートがサポートされている場合True
64
- """
65
- return (hasattr(tokenizer, 'apply_chat_template') and
66
- hasattr(tokenizer, 'chat_template') and
67
- tokenizer.chat_template is not None)
68
-
69
-
70
- def handle_capabilities():
71
- """capabilities API の処理"""
72
- print(json.dumps(capabilities), end='\0', flush=True)
73
-
74
-
75
- def handle_format_test(messages, options=None, tools=None):
76
- """フォーマットテスト API の処理(実際に生成せずフォーマットのみ)"""
77
- if options is None:
78
- options = {}
79
-
80
- result = {
81
- "formatted_prompt": None,
82
- "template_applied": False,
83
- "model_specific_processing": None,
84
- "error": None
85
- }
86
-
87
- try:
88
- # チャットテンプレートが利用可能かチェック
89
- if supports_chat_template():
90
- # messagesはTypeScript側で既にモデル固有処理済み
91
- result["model_specific_processing"] = messages
92
-
93
- # プロンプト生成(フォーマットのみ)
94
- primer = options.get('primer')
95
- add_generation_prompt = True
96
- tokenize = False # 常にテキストで返す
97
-
98
- if primer is not None:
99
- messages.append({'role': 'assistant', 'content': primer})
100
- add_generation_prompt = False
101
-
102
- # tools対応を試みる(テンプレートが対応していなければtools無しで実行)
103
- try:
104
- formatted_prompt = tokenizer.apply_chat_template(
105
- messages,
106
- tools=tools,
107
- add_generation_prompt=add_generation_prompt,
108
- tokenize=tokenize,
109
- )
110
- except TypeError:
111
- formatted_prompt = tokenizer.apply_chat_template(
112
- messages,
113
- add_generation_prompt=add_generation_prompt,
114
- tokenize=tokenize,
115
- )
116
-
117
- if primer is not None:
118
- formatted_prompt = primer.join(formatted_prompt.split(primer)[0:-1]) + primer
119
-
120
- result["formatted_prompt"] = formatted_prompt
121
- result["template_applied"] = True
122
- else:
123
- # チャットテンプレートがない場合はcompletionフォーマット
124
- formatted_prompt = generate_merged_prompt(messages)
125
- primer = options.get('primer')
126
- if primer is not None:
127
- formatted_prompt += primer
128
-
129
- result["formatted_prompt"] = formatted_prompt
130
- result["template_applied"] = False
131
-
132
- except Exception as e:
133
- result["error"] = str(e)
134
-
135
- print(json.dumps(result), end='\0', flush=True)
136
-
137
- def handle_chat(messages, primer=None, options=None, tools=None):
138
- """chat API の処理"""
139
- if options is None:
140
- options = {}
141
-
142
- # チャットテンプレートが利用可能かチェック
143
- if not supports_chat_template():
144
- # チャットテンプレートがない場合はcompletionフォーマットに変換
145
- prompt = generate_merged_prompt(messages)
146
- if primer is not None:
147
- print(primer, end='', flush=True)
148
- generate_text(prompt, options)
149
- return
150
-
151
- # プロンプト生成
152
- add_generation_prompt = True
153
- tokenize = False
154
-
155
- if primer is not None:
156
- messages.append({'role': 'assistant', 'content': primer})
157
- add_generation_prompt = False
158
- tokenize = False
159
-
160
- # tools対応を試みる(テンプレートが対応していなければtools無しで実行)
161
- try:
162
- prompt = tokenizer.apply_chat_template(
163
- messages,
164
- tools=tools,
165
- add_generation_prompt=add_generation_prompt,
166
- tokenize=tokenize,
167
- )
168
- except TypeError:
169
- prompt = tokenizer.apply_chat_template(
170
- messages,
171
- add_generation_prompt=add_generation_prompt,
172
- tokenize=tokenize,
173
- )
174
-
175
- if primer is not None:
176
- prompt = primer.join(prompt.split(primer)[0:-1]) + primer
177
- print(primer, end='', flush=True)
178
-
179
- generate_text(prompt, options)
180
-
181
-
182
- def generate_merged_prompt(messages):
183
- """apply_chat_templateがない場合のプロンプト生成"""
184
- # messagesはTypeScript側で既にmergeSystemMessages処理済み
185
- # TypeScript側のformatterと同じフォーマットを維持
186
-
187
- prompt_parts = []
188
- special_tokens = capabilities.get('special_tokens', {})
189
-
190
- for msg in messages:
191
- role = msg['role'] # 小文字のまま
192
- role_upper = role.upper()
193
-
194
- # 1. 専用のspecial_tokenを探す
195
- role_token = special_tokens.get(role)
22
+ draft_block_size = int(sys.argv[idx + 1])
23
+ except ValueError:
24
+ sys.stderr.write(f"Invalid --draft-block-size value: {sys.argv[idx + 1]}\n")
25
+ sys.exit(1)
196
26
 
197
- if role_token and isinstance(role_token, dict) and 'start' in role_token:
198
- # 専用トークンがある場合
199
- start_token = role_token['start']['text']
200
- end_token = role_token['end']['text']
201
- prompt_parts.extend([
202
- start_token,
203
- msg['content'].strip(),
204
- end_token,
205
- '' # 空行で区切る
206
- ])
207
- else:
208
- # 2. 専用トークンがない場合、汎用blockトークンを探す
209
- # blockやcontextなどの汎用的なペアトークンを探す
210
- block_token = None
211
- for candidate in ['block', 'context', 'quote', 'section']:
212
- token = special_tokens.get(candidate)
213
- if token and isinstance(token, dict) and 'start' in token:
214
- block_token = token
215
- break
216
27
 
217
- if block_token:
218
- # 汎用blockトークンがある場合: {block_begin}{role}:\n...{block_end}
219
- start_token = block_token['start']['text']
220
- end_token = block_token['end']['text']
221
- prompt_parts.extend([
222
- f'{start_token}{role_upper}:\n{msg["content"].strip()}',
223
- end_token,
224
- '' # 空行で区切る
225
- ])
226
- else:
227
- # 3. どちらもない場合は、HTMLコメント形式(フォールバック)
228
- prompt_parts.extend([
229
- f'<!-- begin of {role_upper} -->',
230
- msg['content'].strip(),
231
- f'<!-- end of {role_upper} -->',
232
- '' # 空行で区切る
233
- ])
28
+ def create_backend(model_name: str, text_only: bool = False):
29
+ model_kind = "lm" if text_only else detect_model_kind(model_name)
234
30
 
235
- # 最後の空行を削除して、ダブル改行で結合
236
- return '\n'.join(prompt_parts[:-1])
237
-
238
-
239
- def handle_completion(prompt, options=None, images=None, max_image_size=768):
240
- """completion API の処理
241
-
242
- VLMモデルの場合、TypeScript側でプロンプトにimageトークンが挿入済み。
243
- images が渡された場合は VLM 生成を使用する。
244
- """
245
- if options is None:
246
- options = {}
247
-
248
- # promptはTypeScript側で既にモデル固有処理済み
249
-
250
- if images:
251
- pil_images = load_and_resize_images(images, max_image_size)
252
-
253
- import re
254
- display_prompt = re.sub(r'(<\|image_pad\|>)+', '<|image_pad|>...', prompt)
255
- sys.stderr.write(f"--- vlm completion (images: {len(pil_images)}, max_size: {max_image_size})\n{display_prompt}\n")
256
-
257
- generate_text_vlm(prompt, pil_images, options)
258
- else:
259
- generate_text(prompt, options)
260
-
261
-
262
- def handle_chat_vlm(messages, images, options=None, max_image_size=768, tools=None, primer=None):
263
- """VLMモデル用のチャット処理
264
-
265
- messages: TypeScript側で画像プレースホルダー({type: "image"})が挿入済み
266
- images: 画像ファイルパスの配列(プレースホルダーと位置が対応)
267
- tools: ツール定義(テンプレートが対応している場合のみ使用)
268
- primer: アシスタント応答のプリフィックス
269
- """
270
- if options is None:
271
- options = {}
272
-
273
- # primer処理
274
- add_generation_prompt = True
275
- if primer is not None:
276
- messages.append({'role': 'assistant', 'content': primer})
277
- add_generation_prompt = False
278
-
279
- # processorのapply_chat_templateを直接使用
280
- # systemメッセージのマージはTypeScript側でchat_restrictionsに基づき処理済み
281
- # tools対応を試みる(テンプレートが対応していなければtools無しで実行)
282
- try:
283
- formatted_prompt = processor.apply_chat_template(
284
- messages,
285
- tools=tools,
286
- add_generation_prompt=add_generation_prompt,
287
- tokenize=False,
288
- )
289
- except TypeError:
290
- formatted_prompt = processor.apply_chat_template(
291
- messages,
292
- add_generation_prompt=add_generation_prompt,
293
- tokenize=False,
294
- )
295
-
296
- if primer is not None:
297
- formatted_prompt = primer.join(formatted_prompt.split(primer)[0:-1]) + primer
298
- print(primer, end='', flush=True)
299
-
300
- # 画像ファイルを読み込み・リサイズ
301
- pil_images = load_and_resize_images(images, max_image_size)
302
-
303
- # image_padトークンを省略して表示(大量のパディングで読みづらいため)
304
- import re
305
- display_prompt = re.sub(r'(<\|image_pad\|>)+', '<|image_pad|>...', formatted_prompt)
306
- sys.stderr.write(f"--- vlm prompt (images: {len(pil_images)}, max_size: {max_image_size})\n{display_prompt}\n")
307
-
308
- generate_text_vlm(formatted_prompt, pil_images, options)
309
-
310
-
311
- def generate_text_vlm(prompt, images, options, stop_token_ids=None):
312
- """VLMストリーミング生成"""
313
- temperature = options.pop('temperature', 1.0) if 'temperature' in options else 1.0
314
- max_tokens = options.pop('max_tokens', 1000) if 'max_tokens' in options else 1000
315
- top_p = options.pop('top_p', 0.0) if 'top_p' in options else 0.0
316
- top_k = options.pop('top_k', 0) if 'top_k' in options else 0
317
-
318
- for response in vlm_stream_generate(
319
- model, processor, prompt,
320
- image=images if images else None,
321
- max_tokens=max_tokens,
322
- temperature=temperature,
323
- top_p=top_p,
324
- top_k=top_k,
325
- ):
326
- # 追加 stop token チェック(tool call end 等)
327
- if stop_token_ids and hasattr(response, 'token') and int(response.token) in stop_token_ids:
328
- sys.stderr.write(f"--- stop token detected (vlm): {int(response.token)}\n")
329
- print('\n', end='\0', flush=True)
330
- return
331
- print(response.text.replace('\0', ''), end='', flush=True)
332
-
333
- print('\n', end='\0', flush=True)
334
-
335
-
336
- def generate_text(prompt, options):
337
- """テキスト生成の共通処理
338
-
339
- 注意: optionsはTypeScript側で事前にバリデーション済み
340
- - temperatureパラメータはsamplerオブジェクトに変換
341
- - サポートされていないパラメータはTS側でフィルタリング
342
- """
343
- # デフォルトオプションの設定
344
- default_options = {'max_tokens': 1000}
345
-
346
- # temperatureパラメータを抽出してsamplerを作成
347
- temperature = options.pop('temperature', 1.0) if 'temperature' in options else 1.0
348
- top_p = options.pop('top_p', 0.0) if 'top_p' in options else 0.0
349
- top_k = options.pop('top_k', 0) if 'top_k' in options else 0
350
-
351
- # samplerオブジェクトを作成
352
- sampler = make_sampler(temp=temperature, top_p=top_p, top_k=top_k)
353
-
354
- # 残りのオプションとマージ
355
- final_options = {**default_options, **options, 'sampler': sampler}
356
-
357
- if isinstance(prompt, list): # tokenized
358
- sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
359
- else:
360
- sys.stderr.write(f"--- prompt\n{prompt}\n")
361
-
362
- eos_detected = False
363
- for response in stream_generate(model, tokenizer, prompt, **final_options):
364
- # トークンIDによるEOS判定(より確実)
365
- if is_eod_token(response, tokenizer):
366
- eos_detected = True
367
- print('\n', end='\0', flush=True)
368
- break
369
- if not eos_detected:
370
- print(response.text.replace('\0', ''), end='', flush=True)
371
-
372
- if not eos_detected:
373
- print('\n', end='\0', flush=True)
374
-
375
- def main():
376
- while True:
377
- req = read()
378
- if req is None:
379
- break
380
-
381
- method = req.get('method')
382
- if not method:
383
- sys.stderr.write("Error: 'method' field is required\n")
384
- print('\n', end='\0', flush=True)
385
- continue
386
-
31
+ if model_kind == "vlm":
32
+ backend = MlxVlmBackend()
387
33
  try:
388
- if method == 'capabilities':
389
- handle_capabilities()
390
-
391
- elif method == 'format_test':
392
- messages = req.get('messages')
393
- if not messages:
394
- sys.stderr.write("Error: 'messages' field is required for format_test method\n")
395
- print('\n', end='\0', flush=True)
396
- continue
34
+ backend.load(model_name)
35
+ return backend, "vlm"
36
+ except (ValueError, Exception) as e:
37
+ sys.stderr.write(f"VLM load failed, falling back to LM: {e}\n")
397
38
 
398
- options = req.get('options', {})
399
- tools = req.get('tools')
400
- handle_format_test(messages, options, tools)
39
+ backend = MlxLmBackend()
40
+ backend.load(model_name)
41
+ return backend, "lm"
401
42
 
402
- elif method == 'chat':
403
- messages = req.get('messages')
404
- if not messages:
405
- sys.stderr.write("Error: 'messages' field is required for chat method\n")
406
- print('\n', end='\0', flush=True)
407
- continue
408
43
 
409
- primer = req.get('primer')
410
- options = req.get('options', {})
411
- tools = req.get('tools')
412
- images = req.get('images', [])
44
+ if __name__ == "__main__":
45
+ backend, model_kind = create_backend(model_name, text_only)
413
46
 
414
- if model_kind == "vlm":
415
- max_image_size = req.get('maxImageSize', 768)
416
- handle_chat_vlm(messages, images, options, max_image_size, tools, primer)
417
- else:
418
- handle_chat(messages, primer, options, tools)
419
-
420
- elif method == 'completion':
421
- prompt = req.get('prompt')
422
- if not prompt:
423
- sys.stderr.write("Error: 'prompt' field is required for completion method\n")
424
- print('\n', end='\0', flush=True)
425
- continue
426
-
427
- options = req.get('options', {})
428
- images = req.get('images', [])
429
- max_image_size = req.get('maxImageSize', 768)
430
- handle_completion(prompt, options, images if images else None, max_image_size)
431
-
432
- else:
433
- sys.stderr.write(f"Error: Unknown method '{method}'\n")
434
- print('\n', end='\0', flush=True)
435
-
436
- except Exception as e:
437
- sys.stderr.write(f"Error processing request: {e}\n")
438
- print('\n', end='\0', flush=True)
47
+ if drafter_model:
48
+ backend.load_drafter(drafter_model)
49
+ if draft_block_size is not None and hasattr(backend, 'draft_block_size'):
50
+ backend.draft_block_size = draft_block_size
439
51
 
52
+ capabilities = get_capabilities(backend.get_tokenizer())
53
+ capabilities["model_kind"] = model_kind
54
+ if model_kind == "lm":
55
+ capabilities["methods"].append("cache_prefill")
440
56
 
441
- if __name__ == "__main__":
442
- main()
57
+ server = Server(backend, capabilities)
58
+ server.run()
@@ -0,0 +1,3 @@
1
+ from backends.base import ModelBackend
2
+ from backends.mlx_lm import MlxLmBackend
3
+ from backends.mlx_vlm import MlxVlmBackend
@@ -0,0 +1,84 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Iterator
3
+
4
+
5
+ class ModelBackend(ABC):
6
+ """Abstract base class for model backends."""
7
+
8
+ @abstractmethod
9
+ def load(self, model_name: str) -> None:
10
+ """Load the target model."""
11
+ raise NotImplementedError
12
+
13
+ @abstractmethod
14
+ def get_tokenizer(self) -> Any:
15
+ """Return the tokenizer or processor."""
16
+ raise NotImplementedError
17
+
18
+ @abstractmethod
19
+ def stream_generate(
20
+ self, prompt: str | list[int], options: dict, images: list | None = None,
21
+ prompt_cache: list | None = None,
22
+ ) -> Iterator[Any]:
23
+ """Stream generation results."""
24
+ raise NotImplementedError
25
+
26
+ @abstractmethod
27
+ def supports_vision(self) -> bool:
28
+ """Return whether image input is supported."""
29
+ raise NotImplementedError
30
+
31
+ @property
32
+ @abstractmethod
33
+ def model_kind(self) -> str:
34
+ """Return "lm" or "vlm"."""
35
+ raise NotImplementedError
36
+
37
+ def load_drafter(self, drafter_model: str) -> None:
38
+ """Load a drafter model for speculative decoding."""
39
+ raise NotImplementedError(
40
+ f"{type(self).__name__} does not support drafter models"
41
+ )
42
+
43
+ def has_drafter(self) -> bool:
44
+ """Return whether a drafter model is loaded."""
45
+ return False
46
+
47
+ def cache_prefill(
48
+ self,
49
+ cache_path: str,
50
+ prompt: str,
51
+ base_cache_path: str | None = None,
52
+ trim_to_tokens: int | None = None,
53
+ prefix_offsets: list[int] | None = None,
54
+ prefix_hashes: list[str] | None = None,
55
+ ) -> dict:
56
+ """Build a KV cache from a prompt prefix."""
57
+ raise NotImplementedError(
58
+ f"{type(self).__name__} does not support prompt caching"
59
+ )
60
+
61
+ def load_cache_from_file(self, cache_path: str) -> list | None:
62
+ """Load a prompt cache from file, or None."""
63
+ return None
64
+
65
+ def get_cache_offset(self, prompt_cache: list) -> int:
66
+ """Get the number of tokens stored in a loaded prompt cache."""
67
+ if not prompt_cache:
68
+ return 0
69
+ layer0 = prompt_cache[0]
70
+ if hasattr(layer0, 'offset'):
71
+ off = layer0.offset
72
+ return int(off.item() if hasattr(off, 'item') else off)
73
+ if hasattr(layer0, 'caches'):
74
+ for c in layer0.caches:
75
+ if hasattr(c, 'offset'):
76
+ off = c.offset
77
+ return int(off.item() if hasattr(off, 'item') else off)
78
+ try:
79
+ return int(layer0[0].shape[2])
80
+ except Exception:
81
+ pass
82
+ if hasattr(layer0, 'keys') and layer0.keys is not None:
83
+ return int(layer0.keys.shape[2])
84
+ return 0