@modular-prompt/driver 0.12.0 → 0.13.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/anthropic/anthropic-driver.d.ts +38 -8
- package/dist/anthropic/anthropic-driver.d.ts.map +1 -1
- package/dist/anthropic/anthropic-driver.js +180 -164
- package/dist/anthropic/anthropic-driver.js.map +1 -1
- package/dist/cache-controller.d.ts +31 -0
- package/dist/cache-controller.d.ts.map +1 -0
- package/dist/cache-controller.js +2 -0
- package/dist/cache-controller.js.map +1 -0
- package/dist/cache-utils.d.ts +20 -0
- package/dist/cache-utils.d.ts.map +1 -0
- package/dist/cache-utils.js +71 -0
- package/dist/cache-utils.js.map +1 -0
- package/dist/content-utils.d.ts.map +1 -1
- package/dist/content-utils.js +20 -0
- package/dist/content-utils.js.map +1 -1
- package/dist/driver-registry/config-based-factory.d.ts.map +1 -1
- package/dist/driver-registry/config-based-factory.js +7 -0
- package/dist/driver-registry/config-based-factory.js.map +1 -1
- package/dist/driver-registry/factory-helper.d.ts.map +1 -1
- package/dist/driver-registry/factory-helper.js +7 -4
- package/dist/driver-registry/factory-helper.js.map +1 -1
- package/dist/driver-registry/types.d.ts +6 -0
- package/dist/driver-registry/types.d.ts.map +1 -1
- package/dist/formatter/converter.js +1 -1
- package/dist/formatter/converter.js.map +1 -1
- package/dist/google-genai/element-converter.d.ts +11 -0
- package/dist/google-genai/element-converter.d.ts.map +1 -0
- package/dist/google-genai/element-converter.js +126 -0
- package/dist/google-genai/element-converter.js.map +1 -0
- package/dist/google-genai/google-genai-cache-controller.d.ts +24 -0
- package/dist/google-genai/google-genai-cache-controller.d.ts.map +1 -0
- package/dist/google-genai/google-genai-cache-controller.js +127 -0
- package/dist/google-genai/google-genai-cache-controller.js.map +1 -0
- package/dist/google-genai/google-genai-driver.d.ts +5 -29
- package/dist/google-genai/google-genai-driver.d.ts.map +1 -1
- package/dist/google-genai/google-genai-driver.js +92 -255
- package/dist/google-genai/google-genai-driver.js.map +1 -1
- package/dist/index.d.ts +4 -0
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +3 -0
- package/dist/index.js.map +1 -1
- package/dist/mlx-ml/mlx-cache-controller.d.ts +65 -0
- package/dist/mlx-ml/mlx-cache-controller.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-cache-controller.js +624 -0
- package/dist/mlx-ml/mlx-cache-controller.js.map +1 -0
- package/dist/mlx-ml/mlx-driver.d.ts +12 -7
- package/dist/mlx-ml/mlx-driver.d.ts.map +1 -1
- package/dist/mlx-ml/mlx-driver.js +192 -124
- package/dist/mlx-ml/mlx-driver.js.map +1 -1
- package/dist/mlx-ml/mlx-message-utils.d.ts +9 -0
- package/dist/mlx-ml/mlx-message-utils.d.ts.map +1 -0
- package/dist/mlx-ml/mlx-message-utils.js +71 -0
- package/dist/mlx-ml/mlx-message-utils.js.map +1 -0
- package/dist/mlx-ml/process/index.d.ts +7 -3
- package/dist/mlx-ml/process/index.d.ts.map +1 -1
- package/dist/mlx-ml/process/index.js +22 -7
- package/dist/mlx-ml/process/index.js.map +1 -1
- package/dist/mlx-ml/process/model-handlers.d.ts +4 -59
- package/dist/mlx-ml/process/model-handlers.d.ts.map +1 -1
- package/dist/mlx-ml/process/model-handlers.js +15 -14
- package/dist/mlx-ml/process/model-handlers.js.map +1 -1
- package/dist/mlx-ml/process/model-specific.d.ts +7 -0
- package/dist/mlx-ml/process/model-specific.d.ts.map +1 -1
- package/dist/mlx-ml/process/model-specific.js +3 -0
- package/dist/mlx-ml/process/model-specific.js.map +1 -1
- package/dist/mlx-ml/process/process-communication.d.ts +3 -0
- package/dist/mlx-ml/process/process-communication.d.ts.map +1 -1
- package/dist/mlx-ml/process/process-communication.js +13 -0
- package/dist/mlx-ml/process/process-communication.js.map +1 -1
- package/dist/mlx-ml/process/queue.d.ts +5 -2
- package/dist/mlx-ml/process/queue.d.ts.map +1 -1
- package/dist/mlx-ml/process/queue.js +101 -14
- package/dist/mlx-ml/process/queue.js.map +1 -1
- package/dist/mlx-ml/process/response-processor.d.ts +10 -0
- package/dist/mlx-ml/process/response-processor.d.ts.map +1 -1
- package/dist/mlx-ml/process/response-processor.js +23 -1
- package/dist/mlx-ml/process/response-processor.js.map +1 -1
- package/dist/mlx-ml/process/types.d.ts +50 -4
- package/dist/mlx-ml/process/types.d.ts.map +1 -1
- package/dist/mlx-ml/tool-call-parser/content-parsers.d.ts +9 -0
- package/dist/mlx-ml/tool-call-parser/content-parsers.d.ts.map +1 -0
- package/dist/mlx-ml/tool-call-parser/content-parsers.js +223 -0
- package/dist/mlx-ml/tool-call-parser/content-parsers.js.map +1 -0
- package/dist/mlx-ml/tool-call-parser/detector.d.ts +16 -0
- package/dist/mlx-ml/tool-call-parser/detector.d.ts.map +1 -0
- package/dist/mlx-ml/tool-call-parser/detector.js +58 -0
- package/dist/mlx-ml/tool-call-parser/detector.js.map +1 -0
- package/dist/mlx-ml/tool-call-parser/index.d.ts +7 -0
- package/dist/mlx-ml/tool-call-parser/index.d.ts.map +1 -0
- package/dist/mlx-ml/tool-call-parser/index.js +136 -0
- package/dist/mlx-ml/tool-call-parser/index.js.map +1 -0
- package/dist/mlx-ml/tool-call-parser/tool-formatter.d.ts +8 -0
- package/dist/mlx-ml/tool-call-parser/tool-formatter.d.ts.map +1 -0
- package/dist/mlx-ml/tool-call-parser/tool-formatter.js +88 -0
- package/dist/mlx-ml/tool-call-parser/tool-formatter.js.map +1 -0
- package/dist/mlx-ml/tool-call-parser/types.d.ts +18 -0
- package/dist/mlx-ml/tool-call-parser/types.d.ts.map +1 -0
- package/dist/mlx-ml/tool-call-parser/types.js +2 -0
- package/dist/mlx-ml/tool-call-parser/types.js.map +1 -0
- package/dist/mlx-ml/tool-call-parser/utils.d.ts +5 -0
- package/dist/mlx-ml/tool-call-parser/utils.d.ts.map +1 -0
- package/dist/mlx-ml/tool-call-parser/utils.js +77 -0
- package/dist/mlx-ml/tool-call-parser/utils.js.map +1 -0
- package/dist/types.d.ts +2 -0
- package/dist/types.d.ts.map +1 -1
- package/package.json +9 -4
- package/src/mlx-ml/python/__main__.py +41 -449
- package/src/mlx-ml/python/backends/__init__.py +3 -0
- package/src/mlx-ml/python/backends/base.py +84 -0
- package/src/mlx-ml/python/backends/mlx_lm.py +202 -0
- package/src/mlx-ml/python/backends/mlx_vlm.py +99 -0
- package/src/mlx-ml/python/handlers/__init__.py +6 -0
- package/src/mlx-ml/python/handlers/cache.py +81 -0
- package/src/mlx-ml/python/handlers/capabilities.py +6 -0
- package/src/mlx-ml/python/handlers/chat.py +221 -0
- package/src/mlx-ml/python/handlers/completion.py +36 -0
- package/src/mlx-ml/python/handlers/format_test.py +70 -0
- package/src/mlx-ml/python/handlers/tokenize.py +63 -0
- package/src/mlx-ml/python/pyproject.toml +13 -3
- package/src/mlx-ml/python/server.py +126 -0
- package/src/mlx-ml/python/tests/__init__.py +0 -0
- package/src/mlx-ml/python/utils/__init__.py +0 -0
- package/src/mlx-ml/python/utils/prompt_builder.py +54 -0
- package/src/mlx-ml/python/{token_utils.py → utils/token_utils.py} +9 -40
- package/src/mlx-ml/python/uv.lock +266 -41
- package/dist/mlx-ml/tool-call-parser.d.ts +0 -30
- package/dist/mlx-ml/tool-call-parser.d.ts.map +0 -1
- package/dist/mlx-ml/tool-call-parser.js +0 -623
- package/dist/mlx-ml/tool-call-parser.js.map +0 -1
- /package/src/mlx-ml/python/{example_basic.py → examples/example_basic.py} +0 -0
- /package/src/mlx-ml/python/{example_tool_call.py → examples/example_tool_call.py} +0 -0
- /package/src/mlx-ml/python/{chat_template_constraints.py → utils/chat_template_constraints.py} +0 -0
- /package/src/mlx-ml/python/{vlm_utils.py → utils/vlm_utils.py} +0 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from backends.base import ModelBackend
|
|
4
|
+
from utils.prompt_builder import generate_merged_prompt, supports_chat_template
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_format_test(
    backend: ModelBackend,
    capabilities: dict,
    messages: list,
    options: dict | None = None,
    tools: list | None = None,
) -> None:
    """Format *messages* into a prompt without generating, and print the result.

    When the tokenizer has a chat template, the prompt is rendered with it.
    Otherwise ``generate_merged_prompt`` builds a plain-text prompt. If
    ``options["primer"]`` is set, the prompt ends with that primer text, so
    generation would continue it.

    Prints one NUL-terminated JSON object to stdout with the keys
    ``formatted_prompt``, ``template_applied``, ``model_specific_processing``
    and ``error``. Any exception raised while formatting is reported in
    ``error`` instead of propagating.
    """
    if options is None:
        options = {}

    tokenizer = backend.get_tokenizer()
    result = {
        "formatted_prompt": None,
        "template_applied": False,
        "model_specific_processing": None,
        "error": None,
    }

    try:
        primer = options.get("primer")
        if supports_chat_template(tokenizer):
            result["model_specific_processing"] = messages
            result["formatted_prompt"] = _format_with_chat_template(
                tokenizer, messages, tools, primer
            )
            result["template_applied"] = True
        else:
            formatted_prompt = generate_merged_prompt(messages, capabilities)
            if primer is not None:
                formatted_prompt += primer
            result["formatted_prompt"] = formatted_prompt
            result["template_applied"] = False
    except Exception as e:
        result["error"] = str(e)

    print(json.dumps(result), end="\0", flush=True)


def _render_chat_template(tokenizer, messages, tools, add_generation_prompt):
    """Render *messages* to text, retrying without ``tools`` if the call rejects it."""
    try:
        return tokenizer.apply_chat_template(
            messages,
            tools=tools,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
        )
    except TypeError:
        # Some tokenizers/templates do not accept a ``tools`` argument.
        return tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
        )


def _format_with_chat_template(tokenizer, messages, tools, primer):
    """Return the chat-template prompt, ending with *primer* when one is given."""
    if not primer:
        # No primer (None or ""): a plain generation prompt. An empty primer
        # previously crashed on ``str.split("")``.
        return _render_chat_template(tokenizer, list(messages), tools, True)

    # Render the primer as an (unfinished) assistant turn. The template
    # typically closes the turn after it, so cut right after the primer's
    # last occurrence to leave the prompt open for continuation.
    rendered = _render_chat_template(
        tokenizer,
        [*messages, {"role": "assistant", "content": primer}],
        tools,
        False,
    )
    cut = rendered.rfind(primer)
    if cut != -1:
        return rendered[:cut + len(primer)]

    # The template changed the primer text (e.g. trimmed whitespace), so it
    # cannot be located. Append the raw primer to a generation prompt rather
    # than discarding the rendered conversation.
    return _render_chat_template(tokenizer, list(messages), tools, True) + primer
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from backends.base import ModelBackend
|
|
4
|
+
from utils.prompt_builder import generate_merged_prompt, supports_chat_template
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_tokenize(
    backend: ModelBackend,
    capabilities: dict,
    messages: list,
    tools: list | None = None,
    reasoning_effort: str | None = None,
) -> None:
    """Render *messages* into a prompt, tokenize it, and print the token ids.

    The prompt is rendered with ``add_generation_prompt=False``, so the
    assistant's opening tokens are not included. Prints one NUL-terminated
    JSON object with ``token_ids``, ``token_count`` and ``error``. Any
    exception is reported in ``error`` instead of propagating.
    """
    tok = backend.get_tokenizer()

    payload = {
        "token_ids": None,
        "token_count": 0,
        "error": None,
    }

    try:
        template_kwargs = {}
        if tools is not None:
            template_kwargs["tools"] = tools
        if reasoning_effort is not None:
            template_kwargs["reasoning_effort"] = reasoning_effort

        if not supports_chat_template(tok):
            text = generate_merged_prompt(messages, capabilities)
        else:
            # Fallback chain mirroring chat.py: all extras, then tools only,
            # then no extras at all, for templates rejecting some kwargs.
            tools_only = {"tools": tools} if "tools" in template_kwargs else {}
            attempts = (template_kwargs, tools_only, {})

            text = None
            for kwargs in attempts:
                try:
                    text = tok.apply_chat_template(
                        messages,
                        add_generation_prompt=False,
                        tokenize=False,
                        **kwargs,
                    )
                except TypeError:
                    continue
                break

            if text is None:
                text = str(messages)

        # Let the tokenizer add its special tokens unless the rendered text
        # already begins with BOS (avoids a duplicated BOS).
        bos = tok.bos_token
        already_has_bos = bos is not None and text.startswith(bos)
        ids = tok.encode(text, add_special_tokens=not already_has_bos)

        payload["token_ids"] = ids
        payload["token_count"] = len(ids)
    except Exception as exc:
        payload["error"] = str(exc)

    print(json.dumps(payload), end="\0", flush=True)
|
|
@@ -9,16 +9,26 @@ dependencies = [
|
|
|
9
9
|
"jinja2==3.1.6",
|
|
10
10
|
"mlx>=0.31.2; sys_platform == 'darwin'",
|
|
11
11
|
"mlx-lm==0.31.3; sys_platform == 'darwin'",
|
|
12
|
-
"mlx-vlm==0.
|
|
12
|
+
"mlx-vlm==0.5.0",
|
|
13
13
|
"tokenizers==0.22.2",
|
|
14
14
|
"torch==2.9.1",
|
|
15
15
|
"torchvision==0.24.1",
|
|
16
|
-
"transformers
|
|
16
|
+
"transformers>=5.5.0",
|
|
17
17
|
]
|
|
18
18
|
|
|
19
|
+
[dependency-groups]
|
|
20
|
+
dev = ["pytest>=9.0"]
|
|
21
|
+
|
|
19
22
|
[build-system]
|
|
20
23
|
requires = ["setuptools>=61.0"]
|
|
21
24
|
build-backend = "setuptools.build_meta"
|
|
22
25
|
|
|
26
|
+
[tool.pytest.ini_options]
|
|
27
|
+
testpaths = ["tests"]
|
|
28
|
+
|
|
23
29
|
[tool.setuptools]
|
|
24
|
-
py-modules = ["__main__", "
|
|
30
|
+
py-modules = ["__main__", "server"]
|
|
31
|
+
|
|
32
|
+
[tool.setuptools.packages.find]
|
|
33
|
+
where = ["."]
|
|
34
|
+
include = ["backends*", "handlers*", "utils*"]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""JSON-RPC風サーバー: stdin/stdoutベースのリクエストディスパッチ"""
|
|
2
|
+
import json
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from backends.base import ModelBackend
|
|
6
|
+
from handlers import handle_cache_prefill, handle_capabilities, handle_chat, handle_completion, handle_format_test, handle_tokenize
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
MAX_READ_LINES = 10000


def read():
    """Read the next JSON request from stdin.

    Lines are accumulated until their concatenation parses as JSON, so a
    request may span several lines. If more than ``MAX_READ_LINES`` lines
    accumulate without forming valid JSON, the buffer is discarded (with a
    message on stderr) and accumulation starts over.

    Returns:
        The decoded JSON value, or ``None`` at end of input. A bare JSON
        ``null`` request is discarded rather than returned, because ``None``
        is the end-of-input signal (returning it stopped the server loop).
    """
    lines = []
    while True:
        line = sys.stdin.readline()
        if not line:
            # End of input; any partially buffered request is dropped.
            return None
        lines.append(line)
        if len(lines) > MAX_READ_LINES:
            sys.stderr.write(f"Error: read buffer exceeded {MAX_READ_LINES} lines, discarding\n")
            lines.clear()
            continue
        try:
            value = json.loads(''.join(lines))
        except json.JSONDecodeError:
            # Not a complete JSON document yet; keep reading.
            continue
        if value is None:
            sys.stderr.write("Error: received JSON null request, discarding\n")
            lines.clear()
            continue
        return value
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Server:
    """Dispatch JSON requests read from stdin to the model handlers.

    Requests come from ``read()`` and are routed by their ``method`` field.
    The handlers write their own output to stdout. This class writes only
    ``{"error": ...}`` responses (NUL-terminated) for requests it rejects,
    or for requests whose handler raises.
    """

    def __init__(self, backend: ModelBackend, capabilities: dict):
        self.backend = backend
        self.capabilities = capabilities

    def run(self):
        """Serve requests until ``read()`` reports end of input."""
        while True:
            req = read()
            if req is None:
                break
            self._dispatch(req)

    def _error_response(self, message: str) -> None:
        """Log *message* to stderr and send it as a NUL-terminated error response."""
        sys.stderr.write(f"Error: {message}\n")
        print(json.dumps({"error": message}), end='\0', flush=True)

    def _dispatch(self, req: dict):
        """Validate *req* and invoke the handler selected by ``req['method']``."""
        # read() returns any JSON value; a non-object (list, number, string)
        # would otherwise raise AttributeError on ``req.get`` and crash run().
        if not isinstance(req, dict):
            self._error_response("request must be a JSON object")
            return

        method = req.get('method')
        if not method:
            self._error_response("'method' field is required")
            return

        try:
            if method == 'capabilities':
                handle_capabilities(self.capabilities)

            elif method == 'format_test':
                messages = req.get('messages')
                if not messages:
                    self._error_response("'messages' field is required for format_test method")
                    return
                handle_format_test(self.backend, self.capabilities, messages, req.get('options', {}), req.get('tools'))

            elif method == 'tokenize':
                messages = req.get('messages')
                # Note: unlike the other methods, an empty list is accepted here.
                if messages is None:
                    self._error_response("'messages' field is required for tokenize method")
                    return
                handle_tokenize(
                    self.backend, self.capabilities, messages,
                    tools=req.get('tools'),
                    reasoning_effort=req.get('reasoning_effort'),
                )

            elif method == 'cache_prefill':
                cache_path = req.get('cache_path')
                messages = req.get('messages')
                if not cache_path or not messages:
                    self._error_response("'cache_path' and 'messages' fields are required for cache_prefill")
                    return
                handle_cache_prefill(
                    self.backend, self.capabilities, cache_path, messages,
                    base_cache_path=req.get('base_cache_path'),
                    trim_to_tokens=req.get('trim_to_tokens'),
                    prefix_offsets=req.get('prefix_offsets'),
                    prefix_hashes=req.get('prefix_hashes'),
                    tools=req.get('tools'),
                    reasoning_effort=req.get('reasoning_effort'),
                )

            elif method == 'chat':
                messages = req.get('messages')
                if not messages:
                    self._error_response("'messages' field is required for chat method")
                    return
                handle_chat(
                    self.backend,
                    self.capabilities,
                    messages,
                    primer=req.get('primer'),
                    options=req.get('options', {}),
                    tools=req.get('tools'),
                    images=req.get('images', []),
                    max_image_size=req.get('maxImageSize', 768),
                    reasoning_effort=req.get('reasoning_effort'),
                    cache_path=req.get('cache_path'),
                    cache_trim_tokens=req.get('cache_trim_tokens'),
                )

            elif method == 'completion':
                prompt = req.get('prompt')
                if not prompt:
                    self._error_response("'prompt' field is required for completion method")
                    return
                images = req.get('images', [])
                handle_completion(
                    self.backend,
                    prompt,
                    options=req.get('options', {}),
                    images=images if images else None,
                    max_image_size=req.get('maxImageSize', 768),
                )

            else:
                self._error_response(f"Unknown method '{method}'")

        except Exception as e:
            # Keep the server alive: report handler failures as error responses.
            self._error_response(f"Error processing request: {e}")
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""プロンプト生成ユーティリティ"""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def supports_chat_template(tokenizer) -> bool:
    """Report whether *tokenizer* can render prompts with its own chat template.

    Requires both an ``apply_chat_template`` attribute and a ``chat_template``
    attribute whose value is not ``None``.
    """
    if not hasattr(tokenizer, 'apply_chat_template'):
        return False
    template = getattr(tokenizer, 'chat_template', None)
    return template is not None
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def generate_merged_prompt(messages, capabilities):
    """Build a plain-text prompt for tokenizers without ``apply_chat_template``.

    Each message is rendered with the first applicable option:

    1. the role's own token pair from ``capabilities['special_tokens'][role]``
       (a dict with ``start``/``end`` entries, each carrying ``text``);
    2. the first generic pair found among the ``block``, ``context``,
       ``quote`` and ``section`` special tokens, with ``ROLE:`` as a header;
    3. ``<!-- begin of ROLE -->`` / ``<!-- end of ROLE -->`` comment markers.

    Message content is whitespace-stripped. Missing or ``None`` content is
    treated as an empty string, e.g. for assistant tool-call messages.
    Rendered messages are separated by a blank line.

    Args:
        messages: list of dicts with a ``role`` key and optional ``content``.
        capabilities: dict; only its optional ``special_tokens`` entry is read.

    Returns:
        The merged prompt (``""`` for an empty message list).
    """
    special_tokens = capabilities.get('special_tokens') or {}

    def _is_token_pair(token):
        return bool(token) and isinstance(token, dict) and 'start' in token

    # Loop-invariant: the generic block token is the same for every message.
    block_token = next(
        (
            special_tokens.get(name)
            for name in ('block', 'context', 'quote', 'section')
            if _is_token_pair(special_tokens.get(name))
        ),
        None,
    )

    prompt_parts = []
    for msg in messages:
        role = msg['role']
        role_upper = role.upper()
        content = (msg.get('content') or '').strip()
        role_token = special_tokens.get(role)

        if _is_token_pair(role_token):
            prompt_parts.extend([
                role_token['start']['text'],
                content,
                role_token['end']['text'],
                '',
            ])
        elif block_token is not None:
            prompt_parts.extend([
                f"{block_token['start']['text']}{role_upper}:\n{content}",
                block_token['end']['text'],
                '',
            ])
        else:
            prompt_parts.extend([
                f'<!-- begin of {role_upper} -->',
                content,
                f'<!-- end of {role_upper} -->',
                '',
            ])

    # Drop the trailing '' separator so the prompt does not end with a newline.
    return '\n'.join(prompt_parts[:-1])
|
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
トークン関連のユーティリティ関数
|
|
3
3
|
"""
|
|
4
|
-
import
|
|
5
|
-
from chat_template_constraints import detect_chat_restrictions
|
|
4
|
+
from utils.chat_template_constraints import detect_chat_restrictions
|
|
6
5
|
|
|
7
6
|
|
|
8
7
|
def is_eod_token(response, tokenizer):
|
|
@@ -188,50 +187,32 @@ def get_special_tokens(tokenizer):
|
|
|
188
187
|
def detect_tool_call_format(tokenizer):
|
|
189
188
|
"""tokenizer設定からtool call/resultのデリミタを検出する
|
|
190
189
|
|
|
191
|
-
|
|
192
|
-
- tool_parser_type:
|
|
193
|
-
- chat_template
|
|
190
|
+
tokenizerアクセスが必要な情報のみを抽出する:
|
|
191
|
+
- tool_parser_type: tokenizer_config由来のパーサー種別文字列
|
|
192
|
+
- chat_templateテキストからのデリミタパターン検出
|
|
193
|
+
|
|
194
|
+
パーサー種別→デリミタのマッピングはTS側(detector.ts)に一元化。
|
|
195
|
+
Python側はtokenizerから生の情報を抽出して渡す役割に専念する。
|
|
194
196
|
|
|
195
197
|
Returns:
|
|
196
198
|
dict or None: {
|
|
197
199
|
"tool_parser_type": str, # tokenizer_configのtool_parser_type
|
|
198
|
-
"call_start": str, #
|
|
199
|
-
"call_end": str, #
|
|
200
|
+
"call_start": str, # chat_templateから検出した開始デリミタ
|
|
201
|
+
"call_end": str, # chat_templateから検出した終了デリミタ
|
|
200
202
|
"response_start": str, # tool responseの開始デリミタ(検出時)
|
|
201
203
|
"response_end": str, # tool responseの終了デリミタ(検出時)
|
|
202
204
|
} or None
|
|
203
205
|
"""
|
|
204
206
|
import re
|
|
205
207
|
|
|
206
|
-
# tool_parser_type を取得
|
|
207
208
|
tool_parser_type = None
|
|
208
209
|
if hasattr(tokenizer, 'init_kwargs'):
|
|
209
210
|
tool_parser_type = tokenizer.init_kwargs.get('tool_parser_type')
|
|
210
211
|
|
|
211
|
-
# 既知パーサーからの逆引き(最優先)
|
|
212
|
-
KNOWN_TOOL_PARSERS = {
|
|
213
|
-
"json_tools": {"call_start": "<tool_call>", "call_end": "</tool_call>"},
|
|
214
|
-
"pythonic": {"call_start": "<|tool_call_start|>", "call_end": "<|tool_call_end|>"},
|
|
215
|
-
"function_gemma": {"call_start": "<start_function_call>", "call_end": "<end_function_call>"},
|
|
216
|
-
"mistral": {"call_start": "[TOOL_CALLS]", "call_end": ""},
|
|
217
|
-
"kimi_k2": {"call_start": "<|tool_calls_section_begin|>", "call_end": "<|tool_calls_section_end|>"},
|
|
218
|
-
"longcat": {"call_start": "<longcat_tool_call>", "call_end": "</longcat_tool_call>"},
|
|
219
|
-
"glm47": {"call_start": "<tool_call>", "call_end": "</tool_call>"},
|
|
220
|
-
"qwen3_coder": {"call_start": "<tool_call>", "call_end": "</tool_call>"},
|
|
221
|
-
"minimax_m2": {"call_start": "<minimax:tool_call>", "call_end": "</minimax:tool_call>"},
|
|
222
|
-
}
|
|
223
|
-
|
|
224
|
-
if tool_parser_type and tool_parser_type in KNOWN_TOOL_PARSERS:
|
|
225
|
-
result = {"tool_parser_type": tool_parser_type}
|
|
226
|
-
result.update(KNOWN_TOOL_PARSERS[tool_parser_type])
|
|
227
|
-
return result
|
|
228
|
-
|
|
229
|
-
# chat_template テキストを取得
|
|
230
212
|
template = getattr(tokenizer, 'chat_template', None)
|
|
231
213
|
if not template and hasattr(tokenizer, 'init_kwargs'):
|
|
232
214
|
template = tokenizer.init_kwargs.get('chat_template', '')
|
|
233
215
|
|
|
234
|
-
# tool_parser_type もテンプレートもなければ非対応
|
|
235
216
|
if not tool_parser_type and not template:
|
|
236
217
|
return None
|
|
237
218
|
|
|
@@ -239,21 +220,13 @@ def detect_tool_call_format(tokenizer):
|
|
|
239
220
|
if tool_parser_type:
|
|
240
221
|
result["tool_parser_type"] = tool_parser_type
|
|
241
222
|
|
|
242
|
-
# テンプレートテキストからデリミタを抽出
|
|
243
223
|
if template:
|
|
244
|
-
# 複数のtool_call関連パターンを順に試行
|
|
245
224
|
tool_call_patterns = [
|
|
246
|
-
# </tool_call>, <|/tool_call|>, <tool_call|> (終了タグ専用)
|
|
247
225
|
(r'<\|?tool_call\|?>', r'</tool_call>|<\|/tool_call\|>|<tool_call\|>'),
|
|
248
|
-
# <|tool_call_start|>...<|tool_call_end|>
|
|
249
226
|
(r'<\|tool_call_start\|>', r'<\|tool_call_end\|>'),
|
|
250
|
-
# <start_function_call>...<end_function_call>
|
|
251
227
|
(r'<start_function_call>', r'<end_function_call>'),
|
|
252
|
-
# <|tool_calls_section_begin|>...<|tool_calls_section_end|>
|
|
253
228
|
(r'<\|tool_calls_section_begin\|>', r'<\|tool_calls_section_end\|>'),
|
|
254
|
-
# <longcat_tool_call>...</longcat_tool_call>
|
|
255
229
|
(r'<longcat_tool_call>', r'</longcat_tool_call>'),
|
|
256
|
-
# <minimax:tool_call>...</minimax:tool_call>
|
|
257
230
|
(r'<minimax:tool_call>', r'</minimax:tool_call>'),
|
|
258
231
|
]
|
|
259
232
|
|
|
@@ -265,8 +238,6 @@ def detect_tool_call_format(tokenizer):
|
|
|
265
238
|
result["call_end"] = end_match.group(0)
|
|
266
239
|
break
|
|
267
240
|
|
|
268
|
-
# Harmony形式の専用検出
|
|
269
|
-
# テンプレート内で "functions." と <|call|> が共存する場合
|
|
270
241
|
if "call_start" not in result:
|
|
271
242
|
has_functions = re.search(r'"functions\."', template)
|
|
272
243
|
has_call = re.search(r'<\|call\|>', template)
|
|
@@ -275,14 +246,12 @@ def detect_tool_call_format(tokenizer):
|
|
|
275
246
|
result["call_start"] = "to=functions."
|
|
276
247
|
result["call_end"] = "<|call|>"
|
|
277
248
|
|
|
278
|
-
# Mistral特殊ケース
|
|
279
249
|
if "call_start" not in result:
|
|
280
250
|
mistral_match = re.search(r'\[TOOL_CALLS\]', template)
|
|
281
251
|
if mistral_match:
|
|
282
252
|
result["call_start"] = "[TOOL_CALLS]"
|
|
283
253
|
result["call_end"] = ""
|
|
284
254
|
|
|
285
|
-
# tool_response タグの検出
|
|
286
255
|
resp_tags = re.findall(r'<[|/]?tool_response[|]?>', template)
|
|
287
256
|
if len(resp_tags) >= 2:
|
|
288
257
|
open_tags = [t for t in resp_tags if '/' not in t]
|