openhands-sdk 1.8.1__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openhands/sdk/agent/agent.py +64 -0
- openhands/sdk/agent/base.py +29 -10
- openhands/sdk/agent/prompts/system_prompt.j2 +1 -0
- openhands/sdk/context/condenser/llm_summarizing_condenser.py +7 -5
- openhands/sdk/context/skills/skill.py +59 -1
- openhands/sdk/context/skills/utils.py +6 -65
- openhands/sdk/context/view.py +6 -11
- openhands/sdk/conversation/base.py +5 -0
- openhands/sdk/conversation/event_store.py +84 -12
- openhands/sdk/conversation/impl/local_conversation.py +7 -0
- openhands/sdk/conversation/impl/remote_conversation.py +16 -3
- openhands/sdk/conversation/state.py +25 -2
- openhands/sdk/conversation/visualizer/base.py +23 -0
- openhands/sdk/critic/__init__.py +4 -1
- openhands/sdk/critic/base.py +17 -20
- openhands/sdk/critic/impl/__init__.py +2 -0
- openhands/sdk/critic/impl/agent_finished.py +9 -5
- openhands/sdk/critic/impl/api/__init__.py +18 -0
- openhands/sdk/critic/impl/api/chat_template.py +232 -0
- openhands/sdk/critic/impl/api/client.py +313 -0
- openhands/sdk/critic/impl/api/critic.py +90 -0
- openhands/sdk/critic/impl/api/taxonomy.py +180 -0
- openhands/sdk/critic/result.py +148 -0
- openhands/sdk/event/conversation_error.py +12 -0
- openhands/sdk/event/llm_convertible/action.py +10 -0
- openhands/sdk/event/llm_convertible/message.py +10 -0
- openhands/sdk/git/cached_repo.py +459 -0
- openhands/sdk/git/utils.py +118 -3
- openhands/sdk/hooks/__init__.py +7 -1
- openhands/sdk/hooks/config.py +154 -45
- openhands/sdk/io/base.py +52 -0
- openhands/sdk/io/local.py +25 -0
- openhands/sdk/io/memory.py +34 -1
- openhands/sdk/llm/llm.py +6 -2
- openhands/sdk/llm/utils/model_features.py +3 -0
- openhands/sdk/llm/utils/telemetry.py +41 -2
- openhands/sdk/plugin/__init__.py +17 -0
- openhands/sdk/plugin/fetch.py +231 -0
- openhands/sdk/plugin/plugin.py +61 -4
- openhands/sdk/plugin/types.py +394 -1
- openhands/sdk/secret/secrets.py +19 -4
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/METADATA +6 -1
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/RECORD +45 -37
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/WHEEL +1 -1
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/top_level.txt +0 -0
|
@@ -23,6 +23,7 @@ from openhands.sdk.security.confirmation_policy import (
|
|
|
23
23
|
ConfirmationPolicyBase,
|
|
24
24
|
NeverConfirm,
|
|
25
25
|
)
|
|
26
|
+
from openhands.sdk.utils.cipher import Cipher
|
|
26
27
|
from openhands.sdk.utils.models import OpenHandsModel
|
|
27
28
|
from openhands.sdk.workspace.base import BaseWorkspace
|
|
28
29
|
|
|
@@ -124,6 +125,7 @@ class ConversationState(OpenHandsModel):
|
|
|
124
125
|
# ===== Private attrs (NOT Fields) =====
|
|
125
126
|
_fs: FileStore = PrivateAttr() # filestore for persistence
|
|
126
127
|
_events: EventLog = PrivateAttr() # now the storage for events
|
|
128
|
+
_cipher: Cipher | None = PrivateAttr(default=None) # cipher for secret encryption
|
|
127
129
|
_autosave_enabled: bool = PrivateAttr(
|
|
128
130
|
default=False
|
|
129
131
|
) # to avoid recursion during init
|
|
@@ -166,8 +168,20 @@ class ConversationState(OpenHandsModel):
|
|
|
166
168
|
def _save_base_state(self, fs: FileStore) -> None:
    """
    Write the base-state snapshot to ``fs`` (events are file-backed and are
    not part of this payload).

    With a cipher configured, secrets are encrypted in the snapshot. Without
    one they are redacted (serialized as '**********'), and a warning is
    logged if the secret registry actually holds any secret sources.
    """
    cipher = self._cipher
    if cipher:
        dump_context = {"cipher": cipher}
    else:
        dump_context = None
        # Warn if secrets exist but no cipher is configured
        secret_sources = self.secret_registry.secret_sources
        if secret_sources:
            logger.warning(
                f"Saving conversation state without cipher - "
                f"{len(secret_sources)} secret(s) will be "
                "redacted and lost on restore. Consider providing a cipher to "
                "preserve secrets."
            )

    snapshot = self.model_dump_json(exclude_none=True, context=dump_context)
    fs.write(BASE_STATE, snapshot)
|
|
172
186
|
|
|
173
187
|
# ===== Factory: open-or-create (no load/save methods needed) =====
|
|
@@ -180,6 +194,7 @@ class ConversationState(OpenHandsModel):
|
|
|
180
194
|
persistence_dir: str | None = None,
|
|
181
195
|
max_iterations: int = 500,
|
|
182
196
|
stuck_detection: bool = True,
|
|
197
|
+
cipher: Cipher | None = None,
|
|
183
198
|
) -> "ConversationState":
|
|
184
199
|
"""Create a new conversation state or resume from persistence.
|
|
185
200
|
|
|
@@ -203,6 +218,10 @@ class ConversationState(OpenHandsModel):
|
|
|
203
218
|
persistence_dir: Directory for persisting state and events
|
|
204
219
|
max_iterations: Maximum iterations per run
|
|
205
220
|
stuck_detection: Whether to enable stuck detection
|
|
221
|
+
cipher: Optional cipher for encrypting/decrypting secrets in
|
|
222
|
+
persisted state. If provided, secrets are encrypted when
|
|
223
|
+
saving and decrypted when loading. If not provided, secrets
|
|
224
|
+
are redacted (lost) on serialization.
|
|
206
225
|
|
|
207
226
|
Returns:
|
|
208
227
|
ConversationState ready for use
|
|
@@ -224,7 +243,9 @@ class ConversationState(OpenHandsModel):
|
|
|
224
243
|
|
|
225
244
|
# ---- Resume path ----
|
|
226
245
|
if base_text:
|
|
227
|
-
|
|
246
|
+
# Use cipher context for decrypting secrets if provided
|
|
247
|
+
context = {"cipher": cipher} if cipher else None
|
|
248
|
+
state = cls.model_validate(json.loads(base_text), context=context)
|
|
228
249
|
|
|
229
250
|
# Restore the conversation with the same id
|
|
230
251
|
if state.id != id:
|
|
@@ -236,6 +257,7 @@ class ConversationState(OpenHandsModel):
|
|
|
236
257
|
# Attach event log early so we can read history for tool verification
|
|
237
258
|
state._fs = file_store
|
|
238
259
|
state._events = EventLog(file_store, dir_path=EVENTS_DIR)
|
|
260
|
+
state._cipher = cipher
|
|
239
261
|
|
|
240
262
|
# Verify compatibility (agent class + tools)
|
|
241
263
|
agent.verify(state.agent, events=state._events)
|
|
@@ -272,6 +294,7 @@ class ConversationState(OpenHandsModel):
|
|
|
272
294
|
)
|
|
273
295
|
state._fs = file_store
|
|
274
296
|
state._events = EventLog(file_store, dir_path=EVENTS_DIR)
|
|
297
|
+
state._cipher = cipher
|
|
275
298
|
state.stats = ConversationStats()
|
|
276
299
|
|
|
277
300
|
state._save_base_state(file_store) # initial snapshot
|
|
@@ -65,3 +65,26 @@ class ConversationVisualizerBase(ABC):
|
|
|
65
65
|
event: The event to visualize
|
|
66
66
|
"""
|
|
67
67
|
pass
|
|
68
|
+
|
|
69
|
+
def create_sub_visualizer(
    self,
    agent_id: str,  # noqa: ARG002
) -> "ConversationVisualizerBase | None":
    """Return a visualizer for a sub-agent spawned during delegation.

    This base implementation does not support sub-agent visualization and
    always returns None, so events from the spawned sub-agent are not
    displayed. Subclasses that support delegation (like DelegationVisualizer)
    should override this hook and build a suitable sub-visualizer.

    Args:
        agent_id: The identifier of the sub-agent being spawned (unused here)

    Returns:
        A visualizer instance for the sub-agent, or None if sub-agent
        visualization is not supported
    """
    return None
|
openhands/sdk/critic/__init__.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
|
1
|
-
from openhands.sdk.critic.base import CriticBase
|
|
1
|
+
from openhands.sdk.critic.base import CriticBase
|
|
2
2
|
from openhands.sdk.critic.impl import (
|
|
3
3
|
AgentFinishedCritic,
|
|
4
|
+
APIBasedCritic,
|
|
4
5
|
EmptyPatchCritic,
|
|
5
6
|
PassCritic,
|
|
6
7
|
)
|
|
8
|
+
from openhands.sdk.critic.result import CriticResult
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
__all__ = [
|
|
10
12
|
"CriticBase",
|
|
11
13
|
"CriticResult",
|
|
12
14
|
"AgentFinishedCritic",
|
|
15
|
+
"APIBasedCritic",
|
|
13
16
|
"EmptyPatchCritic",
|
|
14
17
|
"PassCritic",
|
|
15
18
|
]
|
openhands/sdk/critic/base.py
CHANGED
|
@@ -1,29 +1,15 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
|
4
4
|
|
|
5
|
-
from pydantic import
|
|
5
|
+
from pydantic import Field
|
|
6
6
|
|
|
7
|
-
from openhands.sdk.
|
|
7
|
+
from openhands.sdk.critic.result import CriticResult
|
|
8
8
|
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
THRESHOLD: ClassVar[float] = 0.5
|
|
15
|
-
|
|
16
|
-
score: float = Field(
|
|
17
|
-
description="A predicted probability of success between 0 and 1.",
|
|
18
|
-
ge=0.0,
|
|
19
|
-
le=1.0,
|
|
20
|
-
)
|
|
21
|
-
message: str | None = Field(description="An optional message explaining the score.")
|
|
22
|
-
|
|
23
|
-
@property
|
|
24
|
-
def success(self) -> bool:
|
|
25
|
-
"""Whether the agent is successful."""
|
|
26
|
-
return self.score >= CriticResult.THRESHOLD
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from openhands.sdk.event.base import LLMConvertibleEvent
|
|
27
13
|
|
|
28
14
|
|
|
29
15
|
class CriticBase(DiscriminatedUnionMixin, abc.ABC):
|
|
@@ -31,8 +17,19 @@ class CriticBase(DiscriminatedUnionMixin, abc.ABC):
|
|
|
31
17
|
optional git patch, and returns a score about the quality of agent's action.
|
|
32
18
|
"""
|
|
33
19
|
|
|
20
|
+
mode: Literal["finish_and_message", "all_actions"] = Field(
|
|
21
|
+
default="finish_and_message",
|
|
22
|
+
description=(
|
|
23
|
+
"When to run critic evaluation:\n"
|
|
24
|
+
"- 'finish_and_message': Evaluate on FinishAction and agent"
|
|
25
|
+
" MessageEvent (default, minimal performance impact)\n"
|
|
26
|
+
"- 'all_actions': Evaluate after every agent action (WARNING: "
|
|
27
|
+
"significantly slower due to API calls on each action)"
|
|
28
|
+
),
|
|
29
|
+
)
|
|
30
|
+
|
|
34
31
|
@abc.abstractmethod
def evaluate(
    self,
    events: Sequence["LLMConvertibleEvent"],
    git_patch: str | None = None,
) -> CriticResult:
    """Evaluate the given events (and optional git patch) into a CriticResult.

    Abstract hook: the base implementation has no logic of its own and
    concrete critics are expected to override it.
    """
|
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
"""Critic implementations module."""
|
|
2
2
|
|
|
3
3
|
from openhands.sdk.critic.impl.agent_finished import AgentFinishedCritic
|
|
4
|
+
from openhands.sdk.critic.impl.api import APIBasedCritic
|
|
4
5
|
from openhands.sdk.critic.impl.empty_patch import EmptyPatchCritic
|
|
5
6
|
from openhands.sdk.critic.impl.pass_critic import PassCritic
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
__all__ = [
|
|
9
10
|
"AgentFinishedCritic",
|
|
11
|
+
"APIBasedCritic",
|
|
10
12
|
"EmptyPatchCritic",
|
|
11
13
|
"PassCritic",
|
|
12
14
|
]
|
|
@@ -7,13 +7,17 @@ This critic evaluates whether an agent properly finished a task by checking:
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from collections.abc import Sequence
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
10
11
|
|
|
11
12
|
from openhands.sdk.critic.base import CriticBase, CriticResult
|
|
12
|
-
from openhands.sdk.event import ActionEvent, LLMConvertibleEvent
|
|
13
13
|
from openhands.sdk.logger import get_logger
|
|
14
14
|
from openhands.sdk.tool.builtins.finish import FinishAction
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from openhands.sdk.event.base import LLMConvertibleEvent
|
|
19
|
+
|
|
20
|
+
|
|
17
21
|
logger = get_logger(__name__)
|
|
18
22
|
|
|
19
23
|
|
|
@@ -27,7 +31,7 @@ class AgentFinishedCritic(CriticBase):
|
|
|
27
31
|
"""
|
|
28
32
|
|
|
29
33
|
def evaluate(
|
|
30
|
-
self, events: Sequence[LLMConvertibleEvent], git_patch: str | None = None
|
|
34
|
+
self, events: Sequence["LLMConvertibleEvent"], git_patch: str | None = None
|
|
31
35
|
) -> CriticResult:
|
|
32
36
|
"""
|
|
33
37
|
Evaluate if an agent properly finished with a non-empty git patch.
|
|
@@ -66,18 +70,18 @@ class AgentFinishedCritic(CriticBase):
|
|
|
66
70
|
message="Agent completed with FinishAction and non-empty patch",
|
|
67
71
|
)
|
|
68
72
|
|
|
69
|
-
def _has_finish_action(self, events: Sequence[LLMConvertibleEvent]) -> bool:
|
|
73
|
+
def _has_finish_action(self, events: Sequence["LLMConvertibleEvent"]) -> bool:
|
|
70
74
|
"""Check if the last action was a FinishAction."""
|
|
71
75
|
if not events:
|
|
72
76
|
return False
|
|
73
77
|
|
|
74
78
|
# Look for the last ActionEvent in the history
|
|
79
|
+
from openhands.sdk.event.llm_convertible.action import ActionEvent
|
|
80
|
+
|
|
75
81
|
for event in reversed(events):
|
|
76
82
|
if isinstance(event, ActionEvent):
|
|
77
|
-
# Check if this is a FinishAction
|
|
78
83
|
if event.action and isinstance(event.action, FinishAction):
|
|
79
84
|
return True
|
|
80
|
-
# If we find any other action type, the agent didn't finish
|
|
81
85
|
return False
|
|
82
86
|
|
|
83
87
|
return False
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from openhands.sdk.critic.impl.api.client import (
|
|
2
|
+
ClassificationItem,
|
|
3
|
+
ClassificationResponse,
|
|
4
|
+
CriticClient,
|
|
5
|
+
LabelProbMap,
|
|
6
|
+
UsageTokens,
|
|
7
|
+
)
|
|
8
|
+
from openhands.sdk.critic.impl.api.critic import APIBasedCritic
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"APIBasedCritic",
|
|
13
|
+
"CriticClient",
|
|
14
|
+
"ClassificationItem",
|
|
15
|
+
"ClassificationResponse",
|
|
16
|
+
"LabelProbMap",
|
|
17
|
+
"UsageTokens",
|
|
18
|
+
]
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Standalone chat template implementation using Jinja2.
|
|
3
|
+
|
|
4
|
+
This module provides a lightweight implementation of chat template rendering
|
|
5
|
+
that is compatible with HuggingFace transformers but removes the dependency
|
|
6
|
+
on the full transformers library.
|
|
7
|
+
|
|
8
|
+
The implementation follows the same approach as transformers:
|
|
9
|
+
- Uses Jinja2 for template rendering
|
|
10
|
+
- Loads templates dynamically from tokenizer_config.json
|
|
11
|
+
- Supports caching of compiled templates and fetched configs
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import hashlib
|
|
17
|
+
import json
|
|
18
|
+
from collections.abc import Sequence
|
|
19
|
+
from functools import lru_cache
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any
|
|
22
|
+
from urllib.error import URLError
|
|
23
|
+
from urllib.request import Request, urlopen
|
|
24
|
+
|
|
25
|
+
import jinja2
|
|
26
|
+
from jinja2.ext import loopcontrols
|
|
27
|
+
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Cache directory for downloaded tokenizer configs
|
|
31
|
+
CACHE_DIR = Path.home() / ".cache" / "chat_templates"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _get_cache_path(tokenizer_name: str) -> Path:
|
|
35
|
+
"""Get the cache path for a tokenizer config."""
|
|
36
|
+
# Create a safe filename from the tokenizer name
|
|
37
|
+
safe_name = hashlib.md5(tokenizer_name.encode()).hexdigest()
|
|
38
|
+
return CACHE_DIR / f"{safe_name}_tokenizer_config.json"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _fetch_tokenizer_config(
    tokenizer_name: str, use_cache: bool = True
) -> dict[str, Any]:
    """
    Fetch tokenizer_config.json from HuggingFace Hub.

    Args:
        tokenizer_name: The HuggingFace model/tokenizer name
            (e.g., "Qwen/Qwen3-4B-Instruct-2507")
        use_cache: Whether to use cached config if available. When True, a
            freshly downloaded config is also written to the cache.

    Returns:
        The parsed tokenizer config dictionary

    Raises:
        RuntimeError: If the download fails with a URLError; the original
            error is attached as ``__cause__``.
    """
    cache_path = _get_cache_path(tokenizer_name)

    # Try to load from cache
    if use_cache and cache_path.exists():
        try:
            with open(cache_path, encoding="utf-8") as f:
                return json.load(f)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError: the cached
            # file is truncated or corrupt. Fall through to a fresh download,
            # which rewrites the cache entry below.
            pass

    # Fetch from HuggingFace Hub
    url = f"https://huggingface.co/{tokenizer_name}/raw/main/tokenizer_config.json"

    try:
        request = Request(url, headers={"User-Agent": "chat_template/1.0"})
        with urlopen(request, timeout=30) as response:
            config = json.loads(response.read().decode("utf-8"))
    except URLError as e:
        # Chain the cause so the underlying network error stays in the traceback
        raise RuntimeError(
            f"Failed to fetch tokenizer config from {url}: {e}"
        ) from e

    # Cache the config
    if use_cache:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(config, f)

    return config
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@lru_cache(maxsize=16)
def _compile_jinja_template(chat_template: str) -> jinja2.Template:
    """
    Turn a chat template string into a compiled Jinja2 template.

    The template is compiled in an immutable sandboxed environment with
    ``trim_blocks``/``lstrip_blocks`` enabled and the ``loopcontrols``
    extension loaded. Two helpers are exposed to templates: a ``tojson``
    filter and a ``raise_exception`` global. Compiled templates are memoized
    per template string (up to 16 entries).
    """

    def tojson(
        x: Any,
        ensure_ascii: bool = False,
        indent: int | None = None,
        separators: tuple[str, str] | None = None,
        sort_keys: bool = False,
    ) -> str:
        # Match the transformers implementation - no HTML escaping
        dump_options = {
            "ensure_ascii": ensure_ascii,
            "indent": indent,
            "separators": separators,
            "sort_keys": sort_keys,
        }
        return json.dumps(x, **dump_options)

    def raise_exception(message: str) -> None:
        # Lets a template abort rendering with its own error message
        raise jinja2.exceptions.TemplateError(message)

    env = ImmutableSandboxedEnvironment(
        extensions=[loopcontrols],
        lstrip_blocks=True,
        trim_blocks=True,
    )
    env.globals["raise_exception"] = raise_exception
    env.filters["tojson"] = tojson

    return env.from_string(chat_template)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class ChatTemplateRenderer:
    """
    Renders chat messages through a Jinja2 chat template.

    The template is either supplied directly as a string or read from the
    ``chat_template`` entry of a tokenizer config fetched by tokenizer name.
    It is compiled once, at construction time.
    """

    def __init__(
        self,
        tokenizer_name: str | None = None,
        chat_template: str | None = None,
        use_cache: bool = True,
    ):
        """
        Resolve and compile the chat template.

        Args:
            tokenizer_name: HuggingFace tokenizer name to load template from.
                Only consulted when chat_template is not given; the
                tokenizer_config.json is then fetched for this name.
            chat_template: Direct Jinja2 template string.
                Takes precedence over tokenizer_name.
            use_cache: Whether to cache fetched tokenizer configs.

        Raises:
            ValueError: If neither argument is provided, or the fetched
                tokenizer config has no chat_template entry.
        """
        if chat_template is not None:
            self._chat_template = chat_template
        elif tokenizer_name is None:
            raise ValueError("Either tokenizer_name or chat_template must be provided")
        else:
            tokenizer_config = _fetch_tokenizer_config(
                tokenizer_name, use_cache=use_cache
            )
            self._chat_template = tokenizer_config.get("chat_template")
            if self._chat_template is None:
                raise ValueError(
                    f"No chat_template found in tokenizer config for {tokenizer_name}"
                )

        self._compiled_template = _compile_jinja_template(self._chat_template)

    @property
    def chat_template(self) -> str:
        """The raw Jinja2 chat template string."""
        template_source = self._chat_template
        assert template_source is not None
        return template_source

    def apply_chat_template(
        self,
        messages: Sequence[dict[str, Any]],
        tools: Sequence[dict[str, Any]] | None = None,
        add_generation_prompt: bool = False,
        **kwargs: Any,
    ) -> str:
        """
        Render the compiled template with the given conversation.

        Args:
            messages: List of message dicts with 'role' and 'content' keys.
            tools: Optional list of tool definitions for function calling.
            add_generation_prompt: If True, append assistant prompt at the end.
            **kwargs: Additional template variables.

        Returns:
            Formatted string ready for tokenization.
        """
        template_variables = {
            "messages": messages,
            "tools": tools,
            "add_generation_prompt": add_generation_prompt,
            **kwargs,
        }
        return self._compiled_template.render(template_variables)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# Convenience function for simple use cases
|
|
193
|
+
def apply_chat_template(
    messages: Sequence[dict[str, Any]],
    tokenizer_name: str | None = None,
    chat_template: str | None = None,
    tools: Sequence[dict[str, Any]] | None = None,
    add_generation_prompt: bool = False,
    use_cache: bool = True,
    **kwargs: Any,
) -> str:
    """
    Format messages with a chat template in a single call.

    A new ChatTemplateRenderer is built on every invocation and then used to
    render the messages. When rendering repeatedly with the same tokenizer,
    create a ChatTemplateRenderer once and reuse it instead.

    Args:
        messages: List of message dicts with 'role' and 'content' keys.
        tokenizer_name: HuggingFace tokenizer name to load template from.
        chat_template: Direct Jinja2 template string.
            If provided, tokenizer_name is ignored.
        tools: Optional list of tool definitions for function calling.
        add_generation_prompt: If True, append assistant prompt at the end.
        use_cache: Whether to cache fetched tokenizer configs.
        **kwargs: Additional template variables.

    Returns:
        Formatted string ready for tokenization.
    """
    return ChatTemplateRenderer(
        tokenizer_name=tokenizer_name,
        chat_template=chat_template,
        use_cache=use_cache,
    ).apply_chat_template(
        messages=messages,
        tools=tools,
        add_generation_prompt=add_generation_prompt,
        **kwargs,
    )
|