@modular-prompt/driver 0.4.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +597 -0
- package/dist/anthropic/anthropic-driver.d.ts +47 -0
- package/dist/anthropic/anthropic-driver.d.ts.map +1 -0
- package/dist/anthropic/anthropic-driver.js +217 -0
- package/dist/anthropic/anthropic-driver.js.map +1 -0
- package/dist/driver-registry/ai-service.d.ts +43 -0
- package/dist/driver-registry/ai-service.d.ts.map +1 -0
- package/dist/driver-registry/ai-service.js +77 -0
- package/dist/driver-registry/ai-service.js.map +1 -0
- package/dist/driver-registry/config-based-factory.d.ts +64 -0
- package/dist/driver-registry/config-based-factory.d.ts.map +1 -0
- package/dist/driver-registry/config-based-factory.js +90 -0
- package/dist/driver-registry/config-based-factory.js.map +1 -0
- package/dist/driver-registry/factory-helper.d.ts +49 -0
- package/dist/driver-registry/factory-helper.d.ts.map +1 -0
- package/dist/driver-registry/factory-helper.js +109 -0
- package/dist/driver-registry/factory-helper.js.map +1 -0
- package/dist/driver-registry/index.d.ts +9 -0
- package/dist/driver-registry/index.d.ts.map +1 -0
- package/dist/driver-registry/index.js +8 -0
- package/dist/driver-registry/index.js.map +1 -0
- package/dist/driver-registry/registry.d.ts +50 -0
- package/dist/driver-registry/registry.d.ts.map +1 -0
- package/dist/driver-registry/registry.js +208 -0
- package/dist/driver-registry/registry.js.map +1 -0
- package/dist/driver-registry/types.d.ts +108 -0
- package/dist/driver-registry/types.d.ts.map +1 -0
- package/dist/driver-registry/types.js +6 -0
- package/dist/driver-registry/types.js.map +1 -0
- package/dist/echo-driver.d.ts +88 -0
- package/dist/echo-driver.d.ts.map +1 -0
- package/dist/echo-driver.js +198 -0
- package/dist/echo-driver.js.map +1 -0
- package/dist/formatter/completion-formatter.d.ts +27 -0
- package/dist/formatter/completion-formatter.d.ts.map +1 -0
- package/dist/formatter/completion-formatter.js +84 -0
- package/dist/formatter/completion-formatter.js.map +1 -0
- package/dist/formatter/converter.d.ts +20 -0
- package/dist/formatter/converter.d.ts.map +1 -0
- package/dist/formatter/converter.js +176 -0
- package/dist/formatter/converter.js.map +1 -0
- package/dist/formatter/element-formatters/base.d.ts +34 -0
- package/dist/formatter/element-formatters/base.d.ts.map +1 -0
- package/dist/formatter/element-formatters/base.js +36 -0
- package/dist/formatter/element-formatters/base.js.map +1 -0
- package/dist/formatter/element-formatters/chunk.d.ts +11 -0
- package/dist/formatter/element-formatters/chunk.d.ts.map +1 -0
- package/dist/formatter/element-formatters/chunk.js +12 -0
- package/dist/formatter/element-formatters/chunk.js.map +1 -0
- package/dist/formatter/element-formatters/index.d.ts +14 -0
- package/dist/formatter/element-formatters/index.d.ts.map +1 -0
- package/dist/formatter/element-formatters/index.js +15 -0
- package/dist/formatter/element-formatters/index.js.map +1 -0
- package/dist/formatter/element-formatters/json.d.ts +11 -0
- package/dist/formatter/element-formatters/json.d.ts.map +1 -0
- package/dist/formatter/element-formatters/json.js +27 -0
- package/dist/formatter/element-formatters/json.js.map +1 -0
- package/dist/formatter/element-formatters/material.d.ts +11 -0
- package/dist/formatter/element-formatters/material.d.ts.map +1 -0
- package/dist/formatter/element-formatters/material.js +35 -0
- package/dist/formatter/element-formatters/material.js.map +1 -0
- package/dist/formatter/element-formatters/message.d.ts +13 -0
- package/dist/formatter/element-formatters/message.d.ts.map +1 -0
- package/dist/formatter/element-formatters/message.js +35 -0
- package/dist/formatter/element-formatters/message.js.map +1 -0
- package/dist/formatter/element-formatters/registry.d.ts +29 -0
- package/dist/formatter/element-formatters/registry.d.ts.map +1 -0
- package/dist/formatter/element-formatters/registry.js +82 -0
- package/dist/formatter/element-formatters/registry.js.map +1 -0
- package/dist/formatter/element-formatters/section.d.ts +18 -0
- package/dist/formatter/element-formatters/section.d.ts.map +1 -0
- package/dist/formatter/element-formatters/section.js +46 -0
- package/dist/formatter/element-formatters/section.js.map +1 -0
- package/dist/formatter/element-formatters/string-pattern.d.ts +22 -0
- package/dist/formatter/element-formatters/string-pattern.d.ts.map +1 -0
- package/dist/formatter/element-formatters/string-pattern.js +124 -0
- package/dist/formatter/element-formatters/string-pattern.js.map +1 -0
- package/dist/formatter/element-formatters/text.d.ts +11 -0
- package/dist/formatter/element-formatters/text.d.ts.map +1 -0
- package/dist/formatter/element-formatters/text.js +11 -0
- package/dist/formatter/element-formatters/text.js.map +1 -0
- package/dist/formatter/formatter.d.ts +24 -0
- package/dist/formatter/formatter.d.ts.map +1 -0
- package/dist/formatter/formatter.js +252 -0
- package/dist/formatter/formatter.js.map +1 -0
- package/dist/formatter/types.d.ts +91 -0
- package/dist/formatter/types.d.ts.map +1 -0
- package/dist/formatter/types.js +2 -0
- package/dist/formatter/types.js.map +1 -0
- package/dist/google-genai/google-genai-driver.d.ts +67 -0
- package/dist/google-genai/google-genai-driver.d.ts.map +1 -0
- package/dist/google-genai/google-genai-driver.js +351 -0
- package/dist/google-genai/google-genai-driver.js.map +1 -0
- package/dist/index.d.ts +17 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +23 -0
- package/dist/index.js.map +1 -0
- package/dist/mlx-ml/mlx-driver.d.ts +65 -0
- package/dist/mlx-ml/mlx-driver.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-driver.js +235 -0
- package/dist/mlx-ml/mlx-driver.js.map +1 -0
- package/dist/mlx-ml/model-spec/index.d.ts +7 -0
- package/dist/mlx-ml/model-spec/index.d.ts.map +1 -0
- package/dist/mlx-ml/model-spec/index.js +7 -0
- package/dist/mlx-ml/model-spec/index.js.map +1 -0
- package/dist/mlx-ml/model-spec/types.d.ts +30 -0
- package/dist/mlx-ml/model-spec/types.d.ts.map +1 -0
- package/dist/mlx-ml/model-spec/types.js +7 -0
- package/dist/mlx-ml/model-spec/types.js.map +1 -0
- package/dist/mlx-ml/process/index.d.ts +33 -0
- package/dist/mlx-ml/process/index.d.ts.map +1 -0
- package/dist/mlx-ml/process/index.js +65 -0
- package/dist/mlx-ml/process/index.js.map +1 -0
- package/dist/mlx-ml/process/model-handlers.d.ts +58 -0
- package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -0
- package/dist/mlx-ml/process/model-handlers.js +197 -0
- package/dist/mlx-ml/process/model-handlers.js.map +1 -0
- package/dist/mlx-ml/process/model-specific.d.ts +35 -0
- package/dist/mlx-ml/process/model-specific.d.ts.map +1 -0
- package/dist/mlx-ml/process/model-specific.js +35 -0
- package/dist/mlx-ml/process/model-specific.js.map +1 -0
- package/dist/mlx-ml/process/parameter-mapper.d.ts +17 -0
- package/dist/mlx-ml/process/parameter-mapper.d.ts.map +1 -0
- package/dist/mlx-ml/process/parameter-mapper.js +91 -0
- package/dist/mlx-ml/process/parameter-mapper.js.map +1 -0
- package/dist/mlx-ml/process/parameter-validator.d.ts +55 -0
- package/dist/mlx-ml/process/parameter-validator.d.ts.map +1 -0
- package/dist/mlx-ml/process/parameter-validator.js +203 -0
- package/dist/mlx-ml/process/parameter-validator.js.map +1 -0
- package/dist/mlx-ml/process/process-communication.d.ts +25 -0
- package/dist/mlx-ml/process/process-communication.d.ts.map +1 -0
- package/dist/mlx-ml/process/process-communication.js +117 -0
- package/dist/mlx-ml/process/process-communication.js.map +1 -0
- package/dist/mlx-ml/process/queue.d.ts +30 -0
- package/dist/mlx-ml/process/queue.d.ts.map +1 -0
- package/dist/mlx-ml/process/queue.js +147 -0
- package/dist/mlx-ml/process/queue.js.map +1 -0
- package/dist/mlx-ml/process/types.d.ts +97 -0
- package/dist/mlx-ml/process/types.d.ts.map +1 -0
- package/dist/mlx-ml/process/types.js +2 -0
- package/dist/mlx-ml/process/types.js.map +1 -0
- package/dist/mlx-ml/types.d.ts +66 -0
- package/dist/mlx-ml/types.d.ts.map +1 -0
- package/dist/mlx-ml/types.js +7 -0
- package/dist/mlx-ml/types.js.map +1 -0
- package/dist/ollama/ollama-driver.d.ts +15 -0
- package/dist/ollama/ollama-driver.d.ts.map +1 -0
- package/dist/ollama/ollama-driver.js +15 -0
- package/dist/ollama/ollama-driver.js.map +1 -0
- package/dist/openai/openai-driver.d.ts +71 -0
- package/dist/openai/openai-driver.d.ts.map +1 -0
- package/dist/openai/openai-driver.js +230 -0
- package/dist/openai/openai-driver.js.map +1 -0
- package/dist/test-driver.d.ts +78 -0
- package/dist/test-driver.d.ts.map +1 -0
- package/dist/test-driver.js +193 -0
- package/dist/test-driver.js.map +1 -0
- package/dist/types.d.ts +90 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +2 -0
- package/dist/types.js.map +1 -0
- package/dist/vertexai/vertexai-driver.d.ts +63 -0
- package/dist/vertexai/vertexai-driver.d.ts.map +1 -0
- package/dist/vertexai/vertexai-driver.js +335 -0
- package/dist/vertexai/vertexai-driver.js.map +1 -0
- package/package.json +61 -0
- package/scripts/download-model.js +40 -0
- package/scripts/setup-mlx.js +53 -0
- package/src/mlx-ml/python/.python-version +1 -0
- package/src/mlx-ml/python/__main__.py +312 -0
- package/src/mlx-ml/python/chat_template_constraints.py +164 -0
- package/src/mlx-ml/python/pyproject.toml +19 -0
- package/src/mlx-ml/python/token_utils.py +262 -0
- package/src/mlx-ml/python/uv.lock +1029 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
#!/usr/bin/env node

/**
 * Postinstall helper that prepares the Python environment for the MLX driver.
 *
 * 1. Locate the bundled Python project (dist/mlx-ml/python, falling back to
 *    src/mlx-ml/python).
 * 2. Make sure the `uv` package manager is usable, installing it with the
 *    official install script when it is missing.
 * 3. Create a Python 3.13 virtual environment and install the project into it
 *    in editable mode.
 *
 * Every failure path exits with status 0 so that a missing MLX toolchain never
 * breaks installation of the npm package itself; the MLX driver is simply
 * unavailable in that case.
 */
import { execSync } from 'child_process';
import { existsSync } from 'fs';
import { homedir } from 'os';
import { join, dirname } from 'path';
import { fileURLToPath } from 'url';

const __dirname = dirname(fileURLToPath(import.meta.url));
const pythonDir = join(__dirname, '..', 'src', 'mlx-ml', 'python');
const distPythonDir = join(__dirname, '..', 'dist', 'mlx-ml', 'python');

/**
 * Return a shell-ready command string that runs `uv`, or null if none works.
 *
 * The uv install script places the binary in $XDG_BIN_HOME or ~/.local/bin
 * (older releases used ~/.cargo/bin) and only edits shell profiles, so a
 * freshly installed uv is normally NOT on this process's PATH. Probe those
 * locations explicitly instead of relying on PATH alone.
 */
function findUv() {
  const home = homedir();
  const binDirs = [
    process.env.XDG_BIN_HOME,
    join(home, '.local', 'bin'),
    join(home, '.cargo', 'bin'),
  ].filter(Boolean);
  // Quote absolute paths so directories containing spaces still work in the shell.
  const candidates = ['uv', ...binDirs.map((dir) => JSON.stringify(join(dir, 'uv')))];

  for (const command of candidates) {
    try {
      execSync(`${command} --version`, { stdio: 'ignore' });
      return command;
    } catch {
      // Not usable from here; try the next location.
    }
  }
  return null;
}

console.log('🚀 Setting up MLX driver dependencies...\n');

// Check Python directory (prefer the copy shipped in dist/, fall back to the sources)
const targetDir = existsSync(distPythonDir) ? distPythonDir : pythonDir;

if (!existsSync(targetDir)) {
  console.log('⚠️ MLX Python directory not found. Skipping setup.');
  console.log(' MLX driver will not be available.');
  process.exit(0);
}

console.log(`📁 Working directory: ${targetDir}`);

// Check if uv is installed
let uv = findUv();
if (uv) {
  console.log('✅ uv is installed');
} else {
  console.log('⚠️ uv is not installed. Installing uv...');
  try {
    // NOTE: this pipes a remote install script into sh (the method documented by uv).
    execSync('curl -LsSf https://astral.sh/uv/install.sh | sh', { stdio: 'inherit' });
  } catch {
    // Reported below: success is judged by whether uv can actually be found afterwards.
  }
  uv = findUv();
  if (!uv) {
    console.error('❌ Failed to install uv. Please install it manually:');
    console.error(' curl -LsSf https://astral.sh/uv/install.sh | sh');
    console.error('\n MLX driver will not be available without uv.');
    process.exit(0);
  }
  console.log('✅ uv installed successfully');
}

// Setup Python environment (uses the resolved uv binary, which may not be on PATH)
console.log('\n📦 Setting up Python environment...');
try {
  execSync(`${uv} venv --python 3.13`, { cwd: targetDir, stdio: 'inherit' });
  execSync(`${uv} pip install -e .`, { cwd: targetDir, stdio: 'inherit' });
  console.log('\n✅ MLX driver setup completed successfully!');
  console.log(' You can now use MlxDriver from @modular-prompt/driver');
} catch (error) {
  console.error('❌ Failed to setup Python environment:', error.message);
  console.error(' MLX driver will not be available.');
  process.exit(0);
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.13
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import json
|
|
3
|
+
from mlx_lm import load, stream_generate
|
|
4
|
+
from mlx_lm.sample_utils import make_sampler
|
|
5
|
+
from token_utils import get_capabilities, is_eod_token
|
|
6
|
+
|
|
7
|
+
# Model to load: the first command-line argument, or a small default model.
model_name = (sys.argv[1:] or ["mlx-community/gemma-3-270m-it-qat-4bit"])[0]

# Load the model and tokenizer once at startup; every request shares them.
model, tokenizer = load(model_name)

# Capabilities derived from the tokenizer, computed once. They are returned by
# the 'capabilities' method and their 'special_tokens' entry is consulted when
# building completion-style prompts.
capabilities = get_capabilities(tokenizer)
|
|
13
|
+
|
|
14
|
+
def read():
    """
    Read the next JSON request from stdin.

    Lines are accumulated until their concatenation parses as JSON, so a
    request may span several lines and stray blank lines are absorbed.

    Returns:
        The decoded JSON value, or None once stdin reaches EOF before a
        complete document was read.
    """
    buffered = []
    while True:
        chunk = sys.stdin.readline()
        if not chunk:
            # EOF: nothing (more) parseable is coming.
            return None
        buffered.append(chunk)
        try:
            return json.loads(''.join(buffered))
        except json.JSONDecodeError:
            # Incomplete document so far; keep reading.
            continue
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def supports_chat_template():
    """
    Report whether the loaded tokenizer can render chat templates.

    Two conditions must both hold: the tokenizer exposes an
    ``apply_chat_template`` method, and its ``chat_template`` attribute is set
    (not None). Calling ``apply_chat_template`` without a template raises an
    error, so the method's presence alone is not enough.

    Returns:
        bool: True when chat-template formatting can be used.
    """
    if not hasattr(tokenizer, 'apply_chat_template'):
        return False
    return getattr(tokenizer, 'chat_template', None) is not None
|
|
48
|
+
|
|
49
|
+
def handle_capabilities():
    """Handle the 'capabilities' method: write the cached capabilities as one NUL-terminated JSON document."""
    payload = json.dumps(capabilities)
    print(payload, end='\0', flush=True)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def handle_format_test(messages, options=None):
    """
    Handle the 'format_test' method: build the prompt as 'chat' would, without generating.

    Args:
        messages: list of {'role', 'content'} dicts, already processed
            model-specifically on the TypeScript side. The list is not modified.
        options: optional dict; only 'primer' (assistant prefill text) is read.

    Output:
        Writes one JSON object followed by a NUL character to stdout, with the
        keys 'formatted_prompt', 'template_applied', 'model_specific_processing'
        and 'error' (the exception message if formatting raised, else None).
    """
    if options is None:
        options = {}

    result = {
        "formatted_prompt": None,
        "template_applied": False,
        "model_specific_processing": None,
        "error": None
    }

    try:
        primer = options.get('primer')

        if supports_chat_template():
            # messages were already processed model-specifically on the TypeScript side
            result["model_specific_processing"] = messages

            if primer:
                # Render the primer as the start of the assistant turn, then drop the
                # template's closing tokens after it so generation would continue it.
                # A new list is built so the caller's messages stay untouched.
                with_primer = tokenizer.apply_chat_template(
                    messages + [{'role': 'assistant', 'content': primer}],
                    add_generation_prompt=False,
                    tokenize=False,
                )
                cut = with_primer.rfind(primer)
                if cut != -1:
                    formatted_prompt = with_primer[:cut] + primer
                else:
                    # The template changed the content (e.g. trimmed whitespace), so the
                    # primer cannot be located verbatim; open an assistant turn instead
                    # rather than discarding the whole conversation.
                    formatted_prompt = tokenizer.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=False,
                    ) + primer
            else:
                # No (or empty) primer: a plain generation prompt; always text, never tokens.
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=False,
                )

            result["formatted_prompt"] = formatted_prompt
            result["template_applied"] = True
        else:
            # Without a chat template, fall back to the completion-style merged format
            formatted_prompt = generate_merged_prompt(messages)
            if primer is not None:
                formatted_prompt += primer

            result["formatted_prompt"] = formatted_prompt
            result["template_applied"] = False

    except Exception as e:
        result["error"] = str(e)

    print(json.dumps(result), end='\0', flush=True)
|
|
106
|
+
|
|
107
|
+
def handle_chat(messages, primer=None, options=None):
    """
    Handle the 'chat' method: format the conversation and stream a reply.

    Args:
        messages: list of {'role', 'content'} dicts, already processed
            model-specifically on the TypeScript side. The list is not modified.
        primer: optional assistant prefill. It is placed at the start of the
            assistant turn in the prompt and echoed to stdout first, so the
            streamed output reads primer + continuation.
        options: generation options forwarded to generate_text.
    """
    if options is None:
        options = {}

    if not supports_chat_template():
        # No chat template: fall back to the completion-style merged format.
        # This path is reached when the TypeScript side chose 'chat' although the
        # model has no template. The primer must be part of the prompt (as in
        # handle_format_test); otherwise the echoed primer is not what the model
        # actually continues from.
        prompt = generate_merged_prompt(messages)
        if primer is not None:
            prompt += primer
            print(primer, end='', flush=True)
        generate_text(prompt, options)
        return

    # messages were already processed model-specifically on the TypeScript side
    if primer:
        # Render the primer as the start of the assistant turn, then drop the
        # template's closing tokens after it. A new list is built so the
        # caller's messages stay untouched.
        with_primer = tokenizer.apply_chat_template(
            messages + [{'role': 'assistant', 'content': primer}],
            add_generation_prompt=False,
            tokenize=False,
        )
        cut = with_primer.rfind(primer)
        if cut != -1:
            prompt = with_primer[:cut] + primer
        else:
            # The template changed the content (e.g. trimmed whitespace), so the
            # primer cannot be located verbatim; open an assistant turn instead
            # rather than discarding the whole conversation.
            prompt = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            ) + primer
        print(primer, end='', flush=True)
    else:
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

    generate_text(prompt, options)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def generate_merged_prompt(messages):
    """
    Build a plain-text prompt for tokenizers without a usable chat template.

    The messages are expected to be merged already (mergeSystemMessages) on the
    TypeScript side, and the layout mirrors the TypeScript formatter. Each
    message is wrapped using the first available option:

    1. the role's own special-token pair (``special_tokens[role]``);
    2. a generic pair ('block', 'context', 'quote' or 'section', in that
       order), with an upper-cased ``ROLE:`` header line;
    3. HTML-comment markers as a fallback.

    Message contents are stripped, and messages are separated by one blank line.
    """
    special_tokens = capabilities.get('special_tokens', {})

    def is_pair(token):
        # A usable pair is a non-empty dict that at least carries a 'start' entry
        return bool(token) and isinstance(token, dict) and 'start' in token

    generic = next(
        (
            tok
            for tok in (special_tokens.get(name) for name in ('block', 'context', 'quote', 'section'))
            if is_pair(tok)
        ),
        None,
    )

    sections = []
    for msg in messages:
        role = msg['role']  # keep the original lower-case role for the lookup
        label = role.upper()
        own = special_tokens.get(role)

        if is_pair(own):
            sections.append('\n'.join([own['start']['text'], msg['content'].strip(), own['end']['text']]))
        elif generic:
            sections.append('\n'.join([
                f'{generic["start"]["text"]}{label}:\n{msg["content"].strip()}',
                generic['end']['text'],
            ]))
        else:
            sections.append('\n'.join([
                f'<!-- begin of {label} -->',
                msg['content'].strip(),
                f'<!-- end of {label} -->',
            ]))

    return '\n\n'.join(sections)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def handle_completion(prompt, options=None):
    """
    Handle the 'completion' method: stream a continuation of a raw prompt.

    The prompt arrives already processed model-specifically by the TypeScript
    side, so it is handed to generate_text unchanged; missing options become {}.
    """
    generate_text(prompt, {} if options is None else options)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def generate_text(prompt, options):
    """
    Stream generated text for ``prompt`` to stdout.

    Options are validated on the TypeScript side beforehand. The sampling
    parameters 'temperature' (default 1.0), 'top_p' (default 0.0) and 'top_k'
    (default 0) are turned into an mlx_lm sampler; an explicit None also falls
    back to the default. Every remaining key is forwarded to stream_generate,
    with 'max_tokens' defaulting to 1000.

    Output protocol: generated chunks are written as they arrive, with NUL
    characters removed from the text. The response always ends with a newline
    followed by a NUL character, whether generation stopped on an
    end-of-document token or for any other reason.

    Args:
        prompt: prompt text, or a list (treated as already tokenized; only its
            length is logged).
        options: dict of generation options; it is not modified.
    """
    # Work on a copy so the caller's dict is left intact.
    remaining = dict(options or {})

    def take(key, default):
        value = remaining.pop(key, None)
        return default if value is None else value

    sampler = make_sampler(
        temp=take('temperature', 1.0),
        top_p=take('top_p', 0.0),
        top_k=take('top_k', 0),
    )
    final_options = {'max_tokens': 1000, **remaining, 'sampler': sampler}

    if isinstance(prompt, list):  # tokenized
        sys.stderr.write(f"--- prompt: len={len(prompt)}\n")
    else:
        sys.stderr.write(f"--- prompt\n{prompt}\n")

    for response in stream_generate(model, tokenizer, prompt, **final_options):
        # Token-id based EOS detection (more reliable than matching text)
        if is_eod_token(response, tokenizer):
            break
        print(response.text.replace('\0', ''), end='', flush=True)

    # Exactly one terminator per response, whatever ended the generation.
    print('\n', end='\0', flush=True)
|
|
254
|
+
|
|
255
|
+
def main():
    """
    Request loop: read JSON requests from stdin and dispatch them by 'method'.

    Supported methods:
        'capabilities'
        'format_test' (requires 'messages'; optional 'options')
        'chat'        (requires 'messages'; optional 'primer', 'options')
        'completion'  (requires 'prompt'; optional 'options')

    Each handled request writes a NUL-terminated response to stdout. On any
    error the message goes to stderr and a bare newline plus NUL is written,
    so the client reading stdout is never left waiting. The loop ends when
    read() returns None.
    """
    def fail(message):
        # Report on stderr, but still terminate the response on stdout.
        sys.stderr.write(f"{message}\n")
        print('\n', end='\0', flush=True)

    while True:
        req = read()
        if req is None:
            break

        # Valid JSON that is not an object (e.g. a list) has no .get and would
        # otherwise crash the whole server loop.
        if not isinstance(req, dict):
            fail("Error: request must be a JSON object")
            continue

        method = req.get('method')
        if not method:
            fail("Error: 'method' field is required")
            continue

        try:
            if method == 'capabilities':
                handle_capabilities()

            elif method == 'format_test':
                messages = req.get('messages')
                if not messages:
                    fail("Error: 'messages' field is required for format_test method")
                    continue
                handle_format_test(messages, req.get('options', {}))

            elif method == 'chat':
                messages = req.get('messages')
                if not messages:
                    fail("Error: 'messages' field is required for chat method")
                    continue
                handle_chat(messages, req.get('primer'), req.get('options', {}))

            elif method == 'completion':
                prompt = req.get('prompt')
                if not prompt:
                    fail("Error: 'prompt' field is required for completion method")
                    continue
                handle_completion(prompt, req.get('options', {}))

            else:
                fail(f"Error: Unknown method '{method}'")

        except Exception as e:
            fail(f"Error processing request: {e}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
チャットテンプレートの制約検出
|
|
3
|
+
|
|
4
|
+
tokenizerのapply_chat_templateを使用して、
|
|
5
|
+
モデルがサポートするメッセージパターンの制約を検出する。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def detect_chat_restrictions(tokenizer) -> dict:
|
|
10
|
+
"""
|
|
11
|
+
チャットテンプレートの制約を検出
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
tokenizer: HuggingFace tokenizer (apply_chat_template対応)
|
|
15
|
+
|
|
16
|
+
Returns:
|
|
17
|
+
dict: chat_restrictions情報
|
|
18
|
+
{
|
|
19
|
+
"single_system_at_start": bool,
|
|
20
|
+
"max_system_messages": int,
|
|
21
|
+
"alternating_turns": bool,
|
|
22
|
+
"requires_user_last": bool,
|
|
23
|
+
"allow_empty_messages": bool
|
|
24
|
+
}
|
|
25
|
+
"""
|
|
26
|
+
if not hasattr(tokenizer, 'apply_chat_template'):
|
|
27
|
+
return None
|
|
28
|
+
|
|
29
|
+
# テストパターンを実行
|
|
30
|
+
test_results = {}
|
|
31
|
+
for pattern in _get_test_patterns():
|
|
32
|
+
try:
|
|
33
|
+
tokenizer.apply_chat_template(
|
|
34
|
+
pattern['messages'],
|
|
35
|
+
tokenize=False,
|
|
36
|
+
add_generation_prompt=False
|
|
37
|
+
)
|
|
38
|
+
test_results[pattern['name']] = {'success': True}
|
|
39
|
+
except Exception as e:
|
|
40
|
+
test_results[pattern['name']] = {'error': str(e)}
|
|
41
|
+
|
|
42
|
+
# テスト結果から制約を推論
|
|
43
|
+
return _infer_restrictions_from_results(test_results)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _get_test_patterns():
|
|
47
|
+
"""テストパターンの定義"""
|
|
48
|
+
return [
|
|
49
|
+
# 基本パターン
|
|
50
|
+
{
|
|
51
|
+
'name': 'basic',
|
|
52
|
+
'messages': [
|
|
53
|
+
{'role': 'user', 'content': 'Hello'}
|
|
54
|
+
]
|
|
55
|
+
},
|
|
56
|
+
|
|
57
|
+
# システムメッセージ付き
|
|
58
|
+
{
|
|
59
|
+
'name': 'with-system',
|
|
60
|
+
'messages': [
|
|
61
|
+
{'role': 'system', 'content': 'You are a helpful assistant.'},
|
|
62
|
+
{'role': 'user', 'content': 'Hello'}
|
|
63
|
+
]
|
|
64
|
+
},
|
|
65
|
+
|
|
66
|
+
# 複数システムメッセージ
|
|
67
|
+
{
|
|
68
|
+
'name': 'multi-system',
|
|
69
|
+
'messages': [
|
|
70
|
+
{'role': 'system', 'content': 'First system message.'},
|
|
71
|
+
{'role': 'system', 'content': 'Second system message.'},
|
|
72
|
+
{'role': 'user', 'content': 'Hello'}
|
|
73
|
+
]
|
|
74
|
+
},
|
|
75
|
+
|
|
76
|
+
# 連続ユーザーメッセージ
|
|
77
|
+
{
|
|
78
|
+
'name': 'consecutive-user',
|
|
79
|
+
'messages': [
|
|
80
|
+
{'role': 'user', 'content': 'First question'},
|
|
81
|
+
{'role': 'user', 'content': 'Second question'}
|
|
82
|
+
]
|
|
83
|
+
},
|
|
84
|
+
|
|
85
|
+
# アシスタントで終わる
|
|
86
|
+
{
|
|
87
|
+
'name': 'assistant-last',
|
|
88
|
+
'messages': [
|
|
89
|
+
{'role': 'user', 'content': 'Hello'},
|
|
90
|
+
{'role': 'assistant', 'content': 'Hi there!'}
|
|
91
|
+
]
|
|
92
|
+
},
|
|
93
|
+
|
|
94
|
+
# 交互の会話
|
|
95
|
+
{
|
|
96
|
+
'name': 'alternating',
|
|
97
|
+
'messages': [
|
|
98
|
+
{'role': 'user', 'content': 'Question 1'},
|
|
99
|
+
{'role': 'assistant', 'content': 'Answer 1'},
|
|
100
|
+
{'role': 'user', 'content': 'Question 2'}
|
|
101
|
+
]
|
|
102
|
+
},
|
|
103
|
+
|
|
104
|
+
# 空メッセージ
|
|
105
|
+
{
|
|
106
|
+
'name': 'empty-message',
|
|
107
|
+
'messages': [
|
|
108
|
+
{'role': 'user', 'content': ''}
|
|
109
|
+
]
|
|
110
|
+
},
|
|
111
|
+
|
|
112
|
+
# システムメッセージが途中にある
|
|
113
|
+
{
|
|
114
|
+
'name': 'system-middle',
|
|
115
|
+
'messages': [
|
|
116
|
+
{'role': 'user', 'content': 'First'},
|
|
117
|
+
{'role': 'system', 'content': 'System in middle'},
|
|
118
|
+
{'role': 'user', 'content': 'Second'}
|
|
119
|
+
]
|
|
120
|
+
}
|
|
121
|
+
]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _infer_restrictions_from_results(test_results: dict) -> dict:
|
|
125
|
+
"""
|
|
126
|
+
テスト結果から制約を推論
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
test_results: テストパターン名をキーとした結果の辞書
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
dict: 検出された制約
|
|
133
|
+
"""
|
|
134
|
+
restrictions = {}
|
|
135
|
+
|
|
136
|
+
# システムメッセージの制約を検出
|
|
137
|
+
with_system = test_results.get('with-system')
|
|
138
|
+
multi_system = test_results.get('multi-system')
|
|
139
|
+
|
|
140
|
+
if with_system and 'error' in with_system:
|
|
141
|
+
# 単独のsystemメッセージもエラー → systemロール自体がサポートされていない
|
|
142
|
+
restrictions['max_system_messages'] = 0
|
|
143
|
+
elif multi_system and 'error' in multi_system:
|
|
144
|
+
# 複数はエラーだが単独は成功 → 最大1つまで
|
|
145
|
+
restrictions['single_system_at_start'] = True
|
|
146
|
+
restrictions['max_system_messages'] = 1
|
|
147
|
+
# それ以外(両方成功)→ max_system_messagesキーを設定しない(無制限)
|
|
148
|
+
|
|
149
|
+
# 連続ユーザーメッセージのテスト
|
|
150
|
+
consecutive_user = test_results.get('consecutive-user')
|
|
151
|
+
if consecutive_user and 'error' in consecutive_user:
|
|
152
|
+
restrictions['alternating_turns'] = True
|
|
153
|
+
|
|
154
|
+
# アシスタントで終わるテスト
|
|
155
|
+
assistant_last = test_results.get('assistant-last')
|
|
156
|
+
if assistant_last and 'error' in assistant_last:
|
|
157
|
+
restrictions['requires_user_last'] = True
|
|
158
|
+
|
|
159
|
+
# 空メッセージのテスト
|
|
160
|
+
empty_message = test_results.get('empty-message')
|
|
161
|
+
if empty_message and 'error' in empty_message:
|
|
162
|
+
restrictions['allow_empty_messages'] = False
|
|
163
|
+
|
|
164
|
+
return restrictions if restrictions else None
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "mlx_driver"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "MLX driver for modular-prompt"
|
|
5
|
+
requires-python = ">=3.10,<3.14"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"flex>=6.14.1",
|
|
8
|
+
"hf-xet>=1.1.8",
|
|
9
|
+
"jinja2>=3.1.6",
|
|
10
|
+
"mlx-lm>=0.28.3",
|
|
11
|
+
"torch>=2.9.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[build-system]
|
|
15
|
+
requires = ["setuptools>=61.0"]
|
|
16
|
+
build-backend = "setuptools.build_meta"
|
|
17
|
+
|
|
18
|
+
[tool.setuptools]
|
|
19
|
+
py-modules = ["__main__", "chat_template_constraints", "token_utils"]
|