@modular-prompt/driver 0.4.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. package/LICENSE +21 -0
  2. package/README.md +597 -0
  3. package/dist/anthropic/anthropic-driver.d.ts +47 -0
  4. package/dist/anthropic/anthropic-driver.d.ts.map +1 -0
  5. package/dist/anthropic/anthropic-driver.js +217 -0
  6. package/dist/anthropic/anthropic-driver.js.map +1 -0
  7. package/dist/driver-registry/ai-service.d.ts +43 -0
  8. package/dist/driver-registry/ai-service.d.ts.map +1 -0
  9. package/dist/driver-registry/ai-service.js +77 -0
  10. package/dist/driver-registry/ai-service.js.map +1 -0
  11. package/dist/driver-registry/config-based-factory.d.ts +64 -0
  12. package/dist/driver-registry/config-based-factory.d.ts.map +1 -0
  13. package/dist/driver-registry/config-based-factory.js +90 -0
  14. package/dist/driver-registry/config-based-factory.js.map +1 -0
  15. package/dist/driver-registry/factory-helper.d.ts +49 -0
  16. package/dist/driver-registry/factory-helper.d.ts.map +1 -0
  17. package/dist/driver-registry/factory-helper.js +109 -0
  18. package/dist/driver-registry/factory-helper.js.map +1 -0
  19. package/dist/driver-registry/index.d.ts +9 -0
  20. package/dist/driver-registry/index.d.ts.map +1 -0
  21. package/dist/driver-registry/index.js +8 -0
  22. package/dist/driver-registry/index.js.map +1 -0
  23. package/dist/driver-registry/registry.d.ts +50 -0
  24. package/dist/driver-registry/registry.d.ts.map +1 -0
  25. package/dist/driver-registry/registry.js +208 -0
  26. package/dist/driver-registry/registry.js.map +1 -0
  27. package/dist/driver-registry/types.d.ts +108 -0
  28. package/dist/driver-registry/types.d.ts.map +1 -0
  29. package/dist/driver-registry/types.js +6 -0
  30. package/dist/driver-registry/types.js.map +1 -0
  31. package/dist/echo-driver.d.ts +88 -0
  32. package/dist/echo-driver.d.ts.map +1 -0
  33. package/dist/echo-driver.js +198 -0
  34. package/dist/echo-driver.js.map +1 -0
  35. package/dist/formatter/completion-formatter.d.ts +27 -0
  36. package/dist/formatter/completion-formatter.d.ts.map +1 -0
  37. package/dist/formatter/completion-formatter.js +84 -0
  38. package/dist/formatter/completion-formatter.js.map +1 -0
  39. package/dist/formatter/converter.d.ts +20 -0
  40. package/dist/formatter/converter.d.ts.map +1 -0
  41. package/dist/formatter/converter.js +176 -0
  42. package/dist/formatter/converter.js.map +1 -0
  43. package/dist/formatter/element-formatters/base.d.ts +34 -0
  44. package/dist/formatter/element-formatters/base.d.ts.map +1 -0
  45. package/dist/formatter/element-formatters/base.js +36 -0
  46. package/dist/formatter/element-formatters/base.js.map +1 -0
  47. package/dist/formatter/element-formatters/chunk.d.ts +11 -0
  48. package/dist/formatter/element-formatters/chunk.d.ts.map +1 -0
  49. package/dist/formatter/element-formatters/chunk.js +12 -0
  50. package/dist/formatter/element-formatters/chunk.js.map +1 -0
  51. package/dist/formatter/element-formatters/index.d.ts +14 -0
  52. package/dist/formatter/element-formatters/index.d.ts.map +1 -0
  53. package/dist/formatter/element-formatters/index.js +15 -0
  54. package/dist/formatter/element-formatters/index.js.map +1 -0
  55. package/dist/formatter/element-formatters/json.d.ts +11 -0
  56. package/dist/formatter/element-formatters/json.d.ts.map +1 -0
  57. package/dist/formatter/element-formatters/json.js +27 -0
  58. package/dist/formatter/element-formatters/json.js.map +1 -0
  59. package/dist/formatter/element-formatters/material.d.ts +11 -0
  60. package/dist/formatter/element-formatters/material.d.ts.map +1 -0
  61. package/dist/formatter/element-formatters/material.js +35 -0
  62. package/dist/formatter/element-formatters/material.js.map +1 -0
  63. package/dist/formatter/element-formatters/message.d.ts +13 -0
  64. package/dist/formatter/element-formatters/message.d.ts.map +1 -0
  65. package/dist/formatter/element-formatters/message.js +35 -0
  66. package/dist/formatter/element-formatters/message.js.map +1 -0
  67. package/dist/formatter/element-formatters/registry.d.ts +29 -0
  68. package/dist/formatter/element-formatters/registry.d.ts.map +1 -0
  69. package/dist/formatter/element-formatters/registry.js +82 -0
  70. package/dist/formatter/element-formatters/registry.js.map +1 -0
  71. package/dist/formatter/element-formatters/section.d.ts +18 -0
  72. package/dist/formatter/element-formatters/section.d.ts.map +1 -0
  73. package/dist/formatter/element-formatters/section.js +46 -0
  74. package/dist/formatter/element-formatters/section.js.map +1 -0
  75. package/dist/formatter/element-formatters/string-pattern.d.ts +22 -0
  76. package/dist/formatter/element-formatters/string-pattern.d.ts.map +1 -0
  77. package/dist/formatter/element-formatters/string-pattern.js +124 -0
  78. package/dist/formatter/element-formatters/string-pattern.js.map +1 -0
  79. package/dist/formatter/element-formatters/text.d.ts +11 -0
  80. package/dist/formatter/element-formatters/text.d.ts.map +1 -0
  81. package/dist/formatter/element-formatters/text.js +11 -0
  82. package/dist/formatter/element-formatters/text.js.map +1 -0
  83. package/dist/formatter/formatter.d.ts +24 -0
  84. package/dist/formatter/formatter.d.ts.map +1 -0
  85. package/dist/formatter/formatter.js +252 -0
  86. package/dist/formatter/formatter.js.map +1 -0
  87. package/dist/formatter/types.d.ts +91 -0
  88. package/dist/formatter/types.d.ts.map +1 -0
  89. package/dist/formatter/types.js +2 -0
  90. package/dist/formatter/types.js.map +1 -0
  91. package/dist/google-genai/google-genai-driver.d.ts +67 -0
  92. package/dist/google-genai/google-genai-driver.d.ts.map +1 -0
  93. package/dist/google-genai/google-genai-driver.js +351 -0
  94. package/dist/google-genai/google-genai-driver.js.map +1 -0
  95. package/dist/index.d.ts +17 -0
  96. package/dist/index.d.ts.map +1 -0
  97. package/dist/index.js +23 -0
  98. package/dist/index.js.map +1 -0
  99. package/dist/mlx-ml/mlx-driver.d.ts +65 -0
  100. package/dist/mlx-ml/mlx-driver.d.ts.map +1 -0
  101. package/dist/mlx-ml/mlx-driver.js +235 -0
  102. package/dist/mlx-ml/mlx-driver.js.map +1 -0
  103. package/dist/mlx-ml/model-spec/index.d.ts +7 -0
  104. package/dist/mlx-ml/model-spec/index.d.ts.map +1 -0
  105. package/dist/mlx-ml/model-spec/index.js +7 -0
  106. package/dist/mlx-ml/model-spec/index.js.map +1 -0
  107. package/dist/mlx-ml/model-spec/types.d.ts +30 -0
  108. package/dist/mlx-ml/model-spec/types.d.ts.map +1 -0
  109. package/dist/mlx-ml/model-spec/types.js +7 -0
  110. package/dist/mlx-ml/model-spec/types.js.map +1 -0
  111. package/dist/mlx-ml/process/index.d.ts +33 -0
  112. package/dist/mlx-ml/process/index.d.ts.map +1 -0
  113. package/dist/mlx-ml/process/index.js +65 -0
  114. package/dist/mlx-ml/process/index.js.map +1 -0
  115. package/dist/mlx-ml/process/model-handlers.d.ts +58 -0
  116. package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -0
  117. package/dist/mlx-ml/process/model-handlers.js +197 -0
  118. package/dist/mlx-ml/process/model-handlers.js.map +1 -0
  119. package/dist/mlx-ml/process/model-specific.d.ts +35 -0
  120. package/dist/mlx-ml/process/model-specific.d.ts.map +1 -0
  121. package/dist/mlx-ml/process/model-specific.js +35 -0
  122. package/dist/mlx-ml/process/model-specific.js.map +1 -0
  123. package/dist/mlx-ml/process/parameter-mapper.d.ts +17 -0
  124. package/dist/mlx-ml/process/parameter-mapper.d.ts.map +1 -0
  125. package/dist/mlx-ml/process/parameter-mapper.js +91 -0
  126. package/dist/mlx-ml/process/parameter-mapper.js.map +1 -0
  127. package/dist/mlx-ml/process/parameter-validator.d.ts +55 -0
  128. package/dist/mlx-ml/process/parameter-validator.d.ts.map +1 -0
  129. package/dist/mlx-ml/process/parameter-validator.js +203 -0
  130. package/dist/mlx-ml/process/parameter-validator.js.map +1 -0
  131. package/dist/mlx-ml/process/process-communication.d.ts +25 -0
  132. package/dist/mlx-ml/process/process-communication.d.ts.map +1 -0
  133. package/dist/mlx-ml/process/process-communication.js +117 -0
  134. package/dist/mlx-ml/process/process-communication.js.map +1 -0
  135. package/dist/mlx-ml/process/queue.d.ts +30 -0
  136. package/dist/mlx-ml/process/queue.d.ts.map +1 -0
  137. package/dist/mlx-ml/process/queue.js +147 -0
  138. package/dist/mlx-ml/process/queue.js.map +1 -0
  139. package/dist/mlx-ml/process/types.d.ts +97 -0
  140. package/dist/mlx-ml/process/types.d.ts.map +1 -0
  141. package/dist/mlx-ml/process/types.js +2 -0
  142. package/dist/mlx-ml/process/types.js.map +1 -0
  143. package/dist/mlx-ml/types.d.ts +66 -0
  144. package/dist/mlx-ml/types.d.ts.map +1 -0
  145. package/dist/mlx-ml/types.js +7 -0
  146. package/dist/mlx-ml/types.js.map +1 -0
  147. package/dist/ollama/ollama-driver.d.ts +15 -0
  148. package/dist/ollama/ollama-driver.d.ts.map +1 -0
  149. package/dist/ollama/ollama-driver.js +15 -0
  150. package/dist/ollama/ollama-driver.js.map +1 -0
  151. package/dist/openai/openai-driver.d.ts +71 -0
  152. package/dist/openai/openai-driver.d.ts.map +1 -0
  153. package/dist/openai/openai-driver.js +230 -0
  154. package/dist/openai/openai-driver.js.map +1 -0
  155. package/dist/test-driver.d.ts +78 -0
  156. package/dist/test-driver.d.ts.map +1 -0
  157. package/dist/test-driver.js +193 -0
  158. package/dist/test-driver.js.map +1 -0
  159. package/dist/types.d.ts +90 -0
  160. package/dist/types.d.ts.map +1 -0
  161. package/dist/types.js +2 -0
  162. package/dist/types.js.map +1 -0
  163. package/dist/vertexai/vertexai-driver.d.ts +63 -0
  164. package/dist/vertexai/vertexai-driver.d.ts.map +1 -0
  165. package/dist/vertexai/vertexai-driver.js +335 -0
  166. package/dist/vertexai/vertexai-driver.js.map +1 -0
  167. package/package.json +61 -0
  168. package/scripts/download-model.js +40 -0
  169. package/scripts/setup-mlx.js +53 -0
  170. package/src/mlx-ml/python/.python-version +1 -0
  171. package/src/mlx-ml/python/__main__.py +312 -0
  172. package/src/mlx-ml/python/chat_template_constraints.py +164 -0
  173. package/src/mlx-ml/python/pyproject.toml +19 -0
  174. package/src/mlx-ml/python/token_utils.py +262 -0
  175. package/src/mlx-ml/python/uv.lock +1029 -0
@@ -0,0 +1,53 @@
1
#!/usr/bin/env node

/**
 * Post-install setup for the MLX driver's Python backend.
 *
 * 1. Locate the bundled Python project (dist/mlx-ml/python, falling back to
 *    src/mlx-ml/python).
 * 2. Make sure the `uv` tool is available, installing it with the official
 *    installer script when it is missing.
 * 3. Create a Python 3.13 virtual environment in that directory and install
 *    the project into it in editable mode.
 *
 * Every failure path exits with status 0 on purpose: the MLX driver is
 * optional, so a failed setup only disables it instead of failing the
 * package installation.
 */
import { execSync } from 'child_process';
import { existsSync } from 'fs';
import { homedir } from 'os';
import { delimiter, dirname, join } from 'path';
import { fileURLToPath } from 'url';

const __dirname = dirname(fileURLToPath(import.meta.url));
const pythonDir = join(__dirname, '..', 'src', 'mlx-ml', 'python');
const distPythonDir = join(__dirname, '..', 'dist', 'mlx-ml', 'python');

/** Return true when `uv` can be executed with the current PATH. */
function isUvAvailable() {
  try {
    execSync('uv --version', { stdio: 'ignore' });
    return true;
  } catch {
    return false;
  }
}

/**
 * Prepend the directories the uv installer may write to onto this process's
 * PATH. The installer only edits shell profile files, which do not affect the
 * already-running process. Without this, the `uv` commands below usually
 * fail right after a successful install. Child processes started by execSync
 * inherit process.env, so updating it here is sufficient.
 */
function addUvInstallDirsToPath() {
  const home = homedir();
  const candidates = [
    process.env.XDG_BIN_HOME,
    join(home, '.local', 'bin'),
    join(home, '.cargo', 'bin'), // location used by older installer versions
  ];
  process.env.PATH = [...candidates, process.env.PATH].filter(Boolean).join(delimiter);
}

console.log('🚀 Setting up MLX driver dependencies...\n');

// Check Python directory: prefer the copy shipped in dist/.
const targetDir = existsSync(distPythonDir) ? distPythonDir : pythonDir;

if (!existsSync(targetDir)) {
  console.log('⚠️ MLX Python directory not found. Skipping setup.');
  console.log(' MLX driver will not be available.');
  process.exit(0);
}

console.log(`📁 Working directory: ${targetDir}`);

// Check if uv is installed
if (isUvAvailable()) {
  console.log('✅ uv is installed');
} else {
  console.log('⚠️ uv is not installed. Installing uv...');
  let installed = false;
  try {
    // NOTE(security): this pipes a remotely fetched script into `sh` during
    // package installation.
    execSync('curl -LsSf https://astral.sh/uv/install.sh | sh', { stdio: 'inherit' });
    addUvInstallDirsToPath();
    // Only report success if the freshly installed binary is actually usable.
    installed = isUvAvailable();
  } catch {
    installed = false;
  }
  if (!installed) {
    console.error('❌ Failed to install uv. Please install it manually:');
    console.error(' curl -LsSf https://astral.sh/uv/install.sh | sh');
    console.error('\n MLX driver will not be available without uv.');
    process.exit(0);
  }
  console.log('✅ uv installed successfully');
}

// Setup Python environment
console.log('\n📦 Setting up Python environment...');
try {
  execSync('uv venv --python 3.13', { cwd: targetDir, stdio: 'inherit' });
  execSync('uv pip install -e .', { cwd: targetDir, stdio: 'inherit' });
  console.log('\n✅ MLX driver setup completed successfully!');
  console.log(' You can now use MlxDriver from @modular-prompt/driver');
} catch (error) {
  console.error('❌ Failed to setup Python environment:', error.message);
  console.error(' MLX driver will not be available.');
  process.exit(0);
}
@@ -0,0 +1 @@
1
+ 3.13
@@ -0,0 +1,312 @@
1
import sys
import json
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler
from token_utils import get_capabilities, is_eod_token

# Model to serve: the first command-line argument, or a small default model
# when no argument is given.
model_name = next(iter(sys.argv[1:]), "mlx-community/gemma-3-270m-it-qat-4bit")

# Load the model and tokenizer once at startup; the request handlers below
# use these module-level globals.
model, tokenizer = load(model_name)

# Capability information derived from the tokenizer. It is returned by the
# "capabilities" method and consulted for special tokens when building merged
# prompts.
capabilities = get_capabilities(tokenizer)
13
+
14
def read():
    """Read one JSON request from stdin.

    Lines are accumulated until their concatenation parses as JSON, so a
    single request may span several lines.

    Returns:
        The decoded JSON value, or None once stdin reaches EOF without
        yielding a complete document.
    """
    buffer = []
    while True:
        chunk = sys.stdin.readline()
        if chunk:
            buffer.append(chunk)
        try:
            return json.loads(''.join(buffer))
        except json.JSONDecodeError:
            if not chunk:
                # EOF and the buffered text still does not parse.
                return None
32
+
33
+
34
def supports_chat_template():
    """Return True when the tokenizer can render chat templates.

    Two conditions must hold: the tokenizer exposes ``apply_chat_template``
    and its ``chat_template`` is set (not None). Calling
    ``apply_chat_template`` without a template configured results in an
    error, so checking only for the method is not enough.

    Returns:
        bool: True if chat templates are supported.
    """
    if not hasattr(tokenizer, 'apply_chat_template'):
        return False
    return getattr(tokenizer, 'chat_template', None) is not None
48
+
49
def handle_capabilities():
    """Reply to the ``capabilities`` request.

    Serializes the module-level ``capabilities`` value to JSON and writes it
    to stdout, terminated by a NUL character.
    """
    payload = json.dumps(capabilities)
    sys.stdout.write(payload + '\0')
    sys.stdout.flush()
52
+
53
+
54
def handle_format_test(messages, options=None):
    """Handle the ``format_test`` API: build the prompt without generating.

    Writes one JSON object followed by a NUL character to stdout, with keys:

    - ``formatted_prompt``: the prompt text that would be generated from;
    - ``template_applied``: True when the tokenizer's chat template was used;
    - ``model_specific_processing``: the messages as received (set only on
      the chat-template path);
    - ``error``: the exception message if formatting failed, else None.

    Args:
        messages: List of ``{'role': ..., 'content': ...}`` dicts. Per the
            original comments, model-specific processing was already applied
            on the TypeScript side. The list is not modified.
        options: Optional dict. Only ``options['primer']`` is used: text the
            assistant reply should start with.
    """
    if options is None:
        options = {}

    result = {
        "formatted_prompt": None,
        "template_applied": False,
        "model_specific_processing": None,
        "error": None
    }

    try:
        primer = options.get('primer')

        if supports_chat_template():
            # Snapshot the incoming messages. Previously this stored the same
            # list the primer turn was later appended to, so the reported
            # messages wrongly included that extra assistant turn.
            result["model_specific_processing"] = list(messages)

            # Work on a copy so the caller's list is left untouched.
            chat_messages = list(messages)
            add_generation_prompt = True
            if primer:
                # Render the primer as the start of an assistant turn instead
                # of an empty generation prompt.
                chat_messages.append({'role': 'assistant', 'content': primer})
                add_generation_prompt = False

            formatted_prompt = tokenizer.apply_chat_template(
                chat_messages,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,  # always return text
            )

            if primer:
                # Drop whatever the template emitted after the last occurrence
                # of the primer (e.g. end-of-turn markers) so the prompt ends
                # with the primer and generation continues it.
                head, found, _ = formatted_prompt.rpartition(primer)
                if found:
                    formatted_prompt = head + primer
                else:
                    # The template altered the primer text, so it cannot be
                    # located. Rather than reducing the whole prompt to just
                    # the primer, use a normal generation prompt and append it.
                    formatted_prompt = tokenizer.apply_chat_template(
                        list(messages),
                        add_generation_prompt=True,
                        tokenize=False,
                    ) + primer

            result["formatted_prompt"] = formatted_prompt
            result["template_applied"] = True
        else:
            # No chat template: use the completion-style merged prompt.
            formatted_prompt = generate_merged_prompt(messages)
            if primer:
                formatted_prompt += primer

            result["formatted_prompt"] = formatted_prompt
            result["template_applied"] = False

    except Exception as e:
        # Report formatting failures inside the JSON reply instead of raising.
        result["error"] = str(e)

    print(json.dumps(result), end='\0', flush=True)
106
+
107
def handle_chat(messages, primer=None, options=None):
    """Handle the ``chat`` API: build a prompt from *messages* and stream a reply.

    Args:
        messages: List of ``{'role': ..., 'content': ...}`` dicts. Per the
            original comments, model-specific processing was already applied
            on the TypeScript side. The list is not modified.
        primer: Optional text the assistant reply should start with. When
            given, it is echoed to stdout before the generated continuation.
        options: Generation options forwarded to ``generate_text``.
    """
    if options is None:
        options = {}

    if not supports_chat_template():
        # No chat template: fall back to the completion-style merged prompt.
        # The API is chosen on the TypeScript side, so reaching this branch
        # means chat was selected but the model has no template.
        prompt = generate_merged_prompt(messages)
        # The primer is deliberately not added to the prompt here. The
        # TypeScript side may already have added it when converting to the
        # completion API. It is only echoed to the output.
        # NOTE(review): handle_format_test appends the primer to the merged
        # prompt in this same situation, so the two disagree - confirm which
        # behavior is intended.
        if primer is not None:
            print(primer, end='', flush=True)
        generate_text(prompt, options)
        return

    # Copy so appending the primer turn does not mutate the caller's list.
    chat_messages = list(messages)
    add_generation_prompt = True
    if primer:
        # Render the primer as the start of an assistant turn.
        chat_messages.append({'role': 'assistant', 'content': primer})
        add_generation_prompt = False

    prompt = tokenizer.apply_chat_template(
        chat_messages,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )

    if primer:
        # Cut everything the template emitted after the last occurrence of the
        # primer (e.g. end-of-turn markers) so generation continues it.
        head, found, _ = prompt.rpartition(primer)
        if found:
            prompt = head + primer
        else:
            # The template altered the primer text. Previously the prompt then
            # collapsed to just the primer; use a normal generation prompt
            # followed by the primer instead.
            prompt = tokenizer.apply_chat_template(
                list(messages),
                add_generation_prompt=True,
                tokenize=False,
            ) + primer
        print(primer, end='', flush=True)

    generate_text(prompt, options)
147
+
148
+
149
def generate_merged_prompt(messages, special_tokens=None):
    """Build a single completion-style prompt when no chat template exists.

    Per the original comments, system messages were already merged on the
    TypeScript side, and this layout mirrors the TypeScript formatter. Each
    message is wrapped using the first available option:

    1. the role's own paired special token ``special_tokens[role]``:
       ``start`` line, content line, ``end`` line;
    2. the first paired generic token among ``block``, ``context``,
       ``quote`` and ``section``: ``{start}ROLE:`` + newline + content,
       then the ``end`` line;
    3. HTML-comment markers ``<!-- begin of ROLE -->`` /
       ``<!-- end of ROLE -->``.

    Contents are stripped of surrounding whitespace and consecutive messages
    are separated by one blank line.

    Args:
        messages: Dicts with string ``role`` and ``content`` keys.
        special_tokens: Mapping of token names to dicts shaped like
            ``{'start': {'text': ...}, 'end': {'text': ...}}``. Defaults to
            ``capabilities['special_tokens']`` from the module-level
            capabilities.

    Returns:
        str: The merged prompt ('' for an empty message list).
    """
    if special_tokens is None:
        special_tokens = capabilities.get('special_tokens') or {}

    def is_paired(token):
        # A token is usable only if both halves exist. Previously only
        # 'start' was checked, so a token without 'end' raised KeyError.
        return isinstance(token, dict) and 'start' in token and 'end' in token

    # The generic fallback token does not depend on the message; find it once.
    block_token = next(
        (
            special_tokens[name]
            for name in ('block', 'context', 'quote', 'section')
            if is_paired(special_tokens.get(name))
        ),
        None,
    )

    sections = []
    for msg in messages:
        role = msg['role']  # kept lowercase for the token lookup
        role_upper = role.upper()
        content = msg['content'].strip()
        role_token = special_tokens.get(role)

        if is_paired(role_token):
            # 1. Dedicated special tokens for this role.
            sections.append('\n'.join([
                role_token['start']['text'],
                content,
                role_token['end']['text'],
            ]))
        elif block_token is not None:
            # 2. Generic block token: {block_begin}{ROLE}:\n...{block_end}
            sections.append(
                f"{block_token['start']['text']}{role_upper}:\n{content}\n"
                f"{block_token['end']['text']}"
            )
        else:
            # 3. Fallback: HTML-comment markers.
            sections.append(
                f'<!-- begin of {role_upper} -->\n{content}\n'
                f'<!-- end of {role_upper} -->'
            )

    # One blank line between messages, none at the end.
    return '\n\n'.join(sections)
204
+
205
+
206
def handle_completion(prompt, options=None):
    """Handle the ``completion`` API by generating directly from *prompt*.

    Per the original comments, any model-specific processing of *prompt* was
    already done on the TypeScript side, so it is passed through unchanged.
    """
    generate_text(prompt, {} if options is None else options)
214
+
215
+
216
def generate_text(prompt, options):
    """Stream generated text for *prompt* to stdout.

    Per the original notes, *options* were validated on the TypeScript side
    and unsupported parameters were filtered out there. ``temperature``,
    ``top_p`` and ``top_k`` are turned into a sampler via ``make_sampler``.
    The remaining options are passed to ``stream_generate``, with
    ``max_tokens`` defaulting to 1000.

    Output protocol: each streamed chunk's text is printed with NUL characters
    removed. Streaming stops at the first response for which ``is_eod_token``
    is true; that response's text is not printed. The reply is always
    terminated by ``'\\n'`` followed by a NUL character. The prompt (or its
    length, for a token list) is logged to stderr.

    Args:
        prompt: Prompt text, or a list of token ids (treated as tokenized).
        options: Dict of generation options; it is not modified.
    """
    # Work on a copy: popping the sampler parameters must not mutate the
    # caller's dict. `or {}` also tolerates None.
    options = dict(options or {})

    # Sampler parameters are consumed here; everything else goes to
    # stream_generate unchanged.
    temperature = options.pop('temperature', 1.0)
    top_p = options.pop('top_p', 0.0)
    top_k = options.pop('top_k', 0)
    sampler = make_sampler(temp=temperature, top_p=top_p, top_k=top_k)

    final_options = {'max_tokens': 1000, **options, 'sampler': sampler}

    if isinstance(prompt, list):  # tokenized
        sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
    else:
        sys.stderr.write(f"--- prompt\n{prompt}\n")

    for response in stream_generate(model, tokenizer, prompt, **final_options):
        # End-of-document check by token id (more reliable than text matching).
        if is_eod_token(response, tokenizer):
            break
        # NUL is the reply delimiter, so it must never appear in the content.
        print(response.text.replace('\0', ''), end='', flush=True)

    # Exactly one terminator, whether generation hit EOS or ran out of tokens.
    print('\n', end='\0', flush=True)
254
+
255
def _reply_error(message):
    """Log *message* to stderr and send an empty, NUL-terminated reply."""
    sys.stderr.write(message + "\n")
    print('\n', end='\0', flush=True)


def main():
    """Serve requests read from stdin until EOF.

    Each request is a JSON object with a ``method`` field (``capabilities``,
    ``format_test``, ``chat`` or ``completion``) plus method-specific fields:

    - ``messages`` for ``format_test`` and ``chat``;
    - ``primer`` for ``chat``;
    - ``prompt`` for ``completion``;
    - ``options`` for all except ``capabilities``.

    Invalid requests and handler exceptions are logged to stderr and answered
    with ``'\\n'`` plus a NUL character, so the client always receives a
    terminated reply.
    """
    while True:
        req = read()
        if req is None:
            break

        # Valid JSON that is not an object (list, number, string) used to
        # raise AttributeError here, outside the try, killing the server loop.
        if not isinstance(req, dict):
            _reply_error("Error: request must be a JSON object")
            continue

        method = req.get('method')
        if not method:
            _reply_error("Error: 'method' field is required")
            continue

        try:
            if method == 'capabilities':
                handle_capabilities()

            elif method == 'format_test':
                messages = req.get('messages')
                if not messages:
                    _reply_error("Error: 'messages' field is required for format_test method")
                    continue
                handle_format_test(messages, req.get('options', {}))

            elif method == 'chat':
                messages = req.get('messages')
                if not messages:
                    _reply_error("Error: 'messages' field is required for chat method")
                    continue
                handle_chat(messages, req.get('primer'), req.get('options', {}))

            elif method == 'completion':
                prompt = req.get('prompt')
                if not prompt:
                    _reply_error("Error: 'prompt' field is required for completion method")
                    continue
                handle_completion(prompt, req.get('options', {}))

            else:
                _reply_error(f"Error: Unknown method '{method}'")

        except Exception as e:
            # Top-level boundary: keep serving subsequent requests.
            _reply_error(f"Error processing request: {e}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,164 @@
1
+ """
2
+ チャットテンプレートの制約検出
3
+
4
+ tokenizerのapply_chat_templateを使用して、
5
+ モデルがサポートするメッセージパターンの制約を検出する。
6
+ """
7
+
8
+
9
+ def detect_chat_restrictions(tokenizer) -> dict:
10
+ """
11
+ チャットテンプレートの制約を検出
12
+
13
+ Args:
14
+ tokenizer: HuggingFace tokenizer (apply_chat_template対応)
15
+
16
+ Returns:
17
+ dict: chat_restrictions情報
18
+ {
19
+ "single_system_at_start": bool,
20
+ "max_system_messages": int,
21
+ "alternating_turns": bool,
22
+ "requires_user_last": bool,
23
+ "allow_empty_messages": bool
24
+ }
25
+ """
26
+ if not hasattr(tokenizer, 'apply_chat_template'):
27
+ return None
28
+
29
+ # テストパターンを実行
30
+ test_results = {}
31
+ for pattern in _get_test_patterns():
32
+ try:
33
+ tokenizer.apply_chat_template(
34
+ pattern['messages'],
35
+ tokenize=False,
36
+ add_generation_prompt=False
37
+ )
38
+ test_results[pattern['name']] = {'success': True}
39
+ except Exception as e:
40
+ test_results[pattern['name']] = {'error': str(e)}
41
+
42
+ # テスト結果から制約を推論
43
+ return _infer_restrictions_from_results(test_results)
44
+
45
+
46
+ def _get_test_patterns():
47
+ """テストパターンの定義"""
48
+ return [
49
+ # 基本パターン
50
+ {
51
+ 'name': 'basic',
52
+ 'messages': [
53
+ {'role': 'user', 'content': 'Hello'}
54
+ ]
55
+ },
56
+
57
+ # システムメッセージ付き
58
+ {
59
+ 'name': 'with-system',
60
+ 'messages': [
61
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
62
+ {'role': 'user', 'content': 'Hello'}
63
+ ]
64
+ },
65
+
66
+ # 複数システムメッセージ
67
+ {
68
+ 'name': 'multi-system',
69
+ 'messages': [
70
+ {'role': 'system', 'content': 'First system message.'},
71
+ {'role': 'system', 'content': 'Second system message.'},
72
+ {'role': 'user', 'content': 'Hello'}
73
+ ]
74
+ },
75
+
76
+ # 連続ユーザーメッセージ
77
+ {
78
+ 'name': 'consecutive-user',
79
+ 'messages': [
80
+ {'role': 'user', 'content': 'First question'},
81
+ {'role': 'user', 'content': 'Second question'}
82
+ ]
83
+ },
84
+
85
+ # アシスタントで終わる
86
+ {
87
+ 'name': 'assistant-last',
88
+ 'messages': [
89
+ {'role': 'user', 'content': 'Hello'},
90
+ {'role': 'assistant', 'content': 'Hi there!'}
91
+ ]
92
+ },
93
+
94
+ # 交互の会話
95
+ {
96
+ 'name': 'alternating',
97
+ 'messages': [
98
+ {'role': 'user', 'content': 'Question 1'},
99
+ {'role': 'assistant', 'content': 'Answer 1'},
100
+ {'role': 'user', 'content': 'Question 2'}
101
+ ]
102
+ },
103
+
104
+ # 空メッセージ
105
+ {
106
+ 'name': 'empty-message',
107
+ 'messages': [
108
+ {'role': 'user', 'content': ''}
109
+ ]
110
+ },
111
+
112
+ # システムメッセージが途中にある
113
+ {
114
+ 'name': 'system-middle',
115
+ 'messages': [
116
+ {'role': 'user', 'content': 'First'},
117
+ {'role': 'system', 'content': 'System in middle'},
118
+ {'role': 'user', 'content': 'Second'}
119
+ ]
120
+ }
121
+ ]
122
+
123
+
124
+ def _infer_restrictions_from_results(test_results: dict) -> dict:
125
+ """
126
+ テスト結果から制約を推論
127
+
128
+ Args:
129
+ test_results: テストパターン名をキーとした結果の辞書
130
+
131
+ Returns:
132
+ dict: 検出された制約
133
+ """
134
+ restrictions = {}
135
+
136
+ # システムメッセージの制約を検出
137
+ with_system = test_results.get('with-system')
138
+ multi_system = test_results.get('multi-system')
139
+
140
+ if with_system and 'error' in with_system:
141
+ # 単独のsystemメッセージもエラー → systemロール自体がサポートされていない
142
+ restrictions['max_system_messages'] = 0
143
+ elif multi_system and 'error' in multi_system:
144
+ # 複数はエラーだが単独は成功 → 最大1つまで
145
+ restrictions['single_system_at_start'] = True
146
+ restrictions['max_system_messages'] = 1
147
+ # それ以外(両方成功)→ max_system_messagesキーを設定しない(無制限)
148
+
149
+ # 連続ユーザーメッセージのテスト
150
+ consecutive_user = test_results.get('consecutive-user')
151
+ if consecutive_user and 'error' in consecutive_user:
152
+ restrictions['alternating_turns'] = True
153
+
154
+ # アシスタントで終わるテスト
155
+ assistant_last = test_results.get('assistant-last')
156
+ if assistant_last and 'error' in assistant_last:
157
+ restrictions['requires_user_last'] = True
158
+
159
+ # 空メッセージのテスト
160
+ empty_message = test_results.get('empty-message')
161
+ if empty_message and 'error' in empty_message:
162
+ restrictions['allow_empty_messages'] = False
163
+
164
+ return restrictions if restrictions else None
@@ -0,0 +1,19 @@
1
+ [project]
2
+ name = "mlx_driver"
3
+ version = "0.1.0"
4
+ description = "MLX driver for modular-prompt"
5
+ requires-python = ">=3.10,<3.14"
6
+ dependencies = [
7
+ "flex>=6.14.1",
8
+ "hf-xet>=1.1.8",
9
+ "jinja2>=3.1.6",
10
+ "mlx-lm>=0.28.3",
11
+ "torch>=2.9.0",
12
+ ]
13
+
14
+ [build-system]
15
+ requires = ["setuptools>=61.0"]
16
+ build-backend = "setuptools.build_meta"
17
+
18
+ [tool.setuptools]
19
+ py-modules = ["__main__", "chat_template_constraints", "token_utils"]