openhands-sdk 1.8.2__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- openhands/sdk/agent/agent.py +64 -0
- openhands/sdk/agent/base.py +22 -10
- openhands/sdk/context/skills/skill.py +59 -1
- openhands/sdk/context/skills/utils.py +6 -65
- openhands/sdk/conversation/base.py +5 -0
- openhands/sdk/conversation/impl/remote_conversation.py +16 -3
- openhands/sdk/conversation/visualizer/base.py +23 -0
- openhands/sdk/critic/__init__.py +4 -1
- openhands/sdk/critic/base.py +17 -20
- openhands/sdk/critic/impl/__init__.py +2 -0
- openhands/sdk/critic/impl/agent_finished.py +9 -5
- openhands/sdk/critic/impl/api/__init__.py +18 -0
- openhands/sdk/critic/impl/api/chat_template.py +232 -0
- openhands/sdk/critic/impl/api/client.py +313 -0
- openhands/sdk/critic/impl/api/critic.py +90 -0
- openhands/sdk/critic/impl/api/taxonomy.py +180 -0
- openhands/sdk/critic/result.py +148 -0
- openhands/sdk/event/llm_convertible/action.py +10 -0
- openhands/sdk/event/llm_convertible/message.py +10 -0
- openhands/sdk/git/cached_repo.py +459 -0
- openhands/sdk/git/utils.py +118 -3
- openhands/sdk/hooks/__init__.py +7 -1
- openhands/sdk/hooks/config.py +154 -45
- openhands/sdk/llm/utils/model_features.py +3 -0
- openhands/sdk/plugin/__init__.py +17 -0
- openhands/sdk/plugin/fetch.py +231 -0
- openhands/sdk/plugin/plugin.py +61 -4
- openhands/sdk/plugin/types.py +394 -1
- {openhands_sdk-1.8.2.dist-info → openhands_sdk-1.9.0.dist-info}/METADATA +5 -1
- {openhands_sdk-1.8.2.dist-info → openhands_sdk-1.9.0.dist-info}/RECORD +32 -24
- {openhands_sdk-1.8.2.dist-info → openhands_sdk-1.9.0.dist-info}/WHEEL +1 -1
- {openhands_sdk-1.8.2.dist-info → openhands_sdk-1.9.0.dist-info}/top_level.txt +0 -0

openhands/sdk/critic/impl/api/chat_template.py
@@ -0,0 +1,232 @@
+"""
+Standalone chat template implementation using Jinja2.
+
+This module provides a lightweight implementation of chat template rendering
+that is compatible with HuggingFace transformers but removes the dependency
+on the full transformers library.
+
+The implementation follows the same approach as transformers:
+- Uses Jinja2 for template rendering
+- Loads templates dynamically from tokenizer_config.json
+- Supports caching of compiled templates and fetched configs
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from collections.abc import Sequence
+from functools import lru_cache
+from pathlib import Path
+from typing import Any
+from urllib.error import URLError
+from urllib.request import Request, urlopen
+
+import jinja2
+from jinja2.ext import loopcontrols
+from jinja2.sandbox import ImmutableSandboxedEnvironment
+
+
+# Cache directory for downloaded tokenizer configs
+CACHE_DIR = Path.home() / ".cache" / "chat_templates"
+
+
+def _get_cache_path(tokenizer_name: str) -> Path:
+    """Get the cache path for a tokenizer config."""
+    # Create a safe filename from the tokenizer name
+    safe_name = hashlib.md5(tokenizer_name.encode()).hexdigest()
+    return CACHE_DIR / f"{safe_name}_tokenizer_config.json"
+
+
+def _fetch_tokenizer_config(
+    tokenizer_name: str, use_cache: bool = True
+) -> dict[str, Any]:
+    """
+    Fetch tokenizer_config.json from HuggingFace Hub.
+
+    Args:
+        tokenizer_name: The HuggingFace model/tokenizer name
+            (e.g., "Qwen/Qwen3-4B-Instruct-2507")
+        use_cache: Whether to use cached config if available
+
+    Returns:
+        The parsed tokenizer config dictionary
+    """
+    cache_path = _get_cache_path(tokenizer_name)
+
+    # Try to load from cache
+    if use_cache and cache_path.exists():
+        with open(cache_path, encoding="utf-8") as f:
+            return json.load(f)
+
+    # Fetch from HuggingFace Hub
+    url = f"https://huggingface.co/{tokenizer_name}/raw/main/tokenizer_config.json"
+
+    try:
+        request = Request(url, headers={"User-Agent": "chat_template/1.0"})
+        with urlopen(request, timeout=30) as response:
+            config = json.loads(response.read().decode("utf-8"))
+    except URLError as e:
+        raise RuntimeError(f"Failed to fetch tokenizer config from {url}: {e}")
+
+    # Cache the config
+    if use_cache:
+        CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        with open(cache_path, "w", encoding="utf-8") as f:
+            json.dump(config, f)
+
+    return config
+
+
+@lru_cache(maxsize=16)
+def _compile_jinja_template(chat_template: str) -> jinja2.Template:
+    """
+    Compile a Jinja2 chat template.
+
+    This matches the transformers implementation with custom tojson filter
+    and other utilities.
+    """
+
+    def raise_exception(message: str) -> None:
+        raise jinja2.exceptions.TemplateError(message)
+
+    def tojson(
+        x: Any,
+        ensure_ascii: bool = False,
+        indent: int | None = None,
+        separators: tuple[str, str] | None = None,
+        sort_keys: bool = False,
+    ) -> str:
+        # Match the transformers implementation - no HTML escaping
+        return json.dumps(
+            x,
+            ensure_ascii=ensure_ascii,
+            indent=indent,
+            separators=separators,
+            sort_keys=sort_keys,
+        )
+
+    jinja_env = ImmutableSandboxedEnvironment(
+        trim_blocks=True,
+        lstrip_blocks=True,
+        extensions=[loopcontrols],
+    )
+    jinja_env.filters["tojson"] = tojson
+    jinja_env.globals["raise_exception"] = raise_exception
+
+    return jinja_env.from_string(chat_template)
+
+
+class ChatTemplateRenderer:
+    """
+    A lightweight chat template renderer compatible with HuggingFace transformers.
+
+    This class can dynamically load templates from HuggingFace Hub or use
+    provided templates directly.
+    """
+
+    def __init__(
+        self,
+        tokenizer_name: str | None = None,
+        chat_template: str | None = None,
+        use_cache: bool = True,
+    ):
+        """
+        Initialize the renderer.
+
+        Args:
+            tokenizer_name: HuggingFace tokenizer name to load template from.
+                If provided, will fetch tokenizer_config.json from
+                HuggingFace Hub.
+            chat_template: Direct Jinja2 template string.
+                If provided, tokenizer_name is ignored.
+            use_cache: Whether to cache fetched tokenizer configs.
+        """
+        if chat_template is not None:
+            self._chat_template = chat_template
+        elif tokenizer_name is not None:
+            config = _fetch_tokenizer_config(tokenizer_name, use_cache=use_cache)
+            self._chat_template = config.get("chat_template")
+            if self._chat_template is None:
+                raise ValueError(
+                    f"No chat_template found in tokenizer config for {tokenizer_name}"
+                )
+        else:
+            raise ValueError("Either tokenizer_name or chat_template must be provided")
+
+        self._compiled_template = _compile_jinja_template(self._chat_template)
+
+    @property
+    def chat_template(self) -> str:
+        """The raw Jinja2 chat template string."""
+        assert self._chat_template is not None
+        return self._chat_template
+
+    def apply_chat_template(
+        self,
+        messages: Sequence[dict[str, Any]],
+        tools: Sequence[dict[str, Any]] | None = None,
+        add_generation_prompt: bool = False,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Apply the chat template to format messages.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content' keys.
+            tools: Optional list of tool definitions for function calling.
+            add_generation_prompt: If True, append assistant prompt at the end.
+            **kwargs: Additional template variables.
+
+        Returns:
+            Formatted string ready for tokenization.
+        """
+        return self._compiled_template.render(
+            messages=messages,
+            tools=tools,
+            add_generation_prompt=add_generation_prompt,
+            **kwargs,
+        )
+
+
+# Convenience function for simple use cases
+def apply_chat_template(
+    messages: Sequence[dict[str, Any]],
+    tokenizer_name: str | None = None,
+    chat_template: str | None = None,
+    tools: Sequence[dict[str, Any]] | None = None,
+    add_generation_prompt: bool = False,
+    use_cache: bool = True,
+    **kwargs: Any,
+) -> str:
+    """
+    Apply a chat template to format messages.
+
+    This is a convenience function that creates a renderer and applies the
+    template. For repeated use with the same tokenizer, prefer using
+    ChatTemplateRenderer directly.
+
+    Args:
+        messages: List of message dicts with 'role' and 'content' keys.
+        tokenizer_name: HuggingFace tokenizer name to load template from.
+        chat_template: Direct Jinja2 template string.
+            If provided, tokenizer_name is ignored.
+        tools: Optional list of tool definitions for function calling.
+        add_generation_prompt: If True, append assistant prompt at the end.
+        use_cache: Whether to cache fetched tokenizer configs.
+        **kwargs: Additional template variables.
+
+    Returns:
+        Formatted string ready for tokenization.
+    """
+    renderer = ChatTemplateRenderer(
+        tokenizer_name=tokenizer_name,
+        chat_template=chat_template,
+        use_cache=use_cache,
+    )
+    return renderer.apply_chat_template(
+        messages=messages,
+        tools=tools,
+        add_generation_prompt=add_generation_prompt,
+        **kwargs,
+    )
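
For context, a minimal usage sketch of the new module (an editorial illustration, not part of the published diff); the import path follows the file layout in the summary above, and the tokenizer name is only an example, since any model whose tokenizer_config.json carries a chat_template should work:

    from openhands.sdk.critic.impl.api.chat_template import (
        ChatTemplateRenderer,
        apply_chat_template,
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]

    # One-off rendering via the module-level convenience function
    text = apply_chat_template(messages, tokenizer_name="Qwen/Qwen3-4B-Instruct-2507")

    # For repeated calls, build the renderer once so the tokenizer config is
    # fetched (and cached under ~/.cache/chat_templates) a single time
    renderer = ChatTemplateRenderer(tokenizer_name="Qwen/Qwen3-4B-Instruct-2507")
    text = renderer.apply_chat_template(messages, add_generation_prompt=True)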

openhands/sdk/critic/impl/api/client.py
@@ -0,0 +1,313 @@
+import copy
+from collections.abc import Sequence
+from typing import Any, cast
+
+import httpx
+from litellm import ChatCompletionToolParam
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    field_validator,
+)
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+
+from .chat_template import ChatTemplateRenderer
+
+
+# ============================================================
+# Typed API response models
+# ============================================================
+
+
+class UsageTokens(BaseModel):
+    prompt_tokens: int | None = None
+    total_tokens: int | None = None
+    completion_tokens: int | None = None
+    prompt_tokens_details: dict | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class ClassificationItem(BaseModel):
+    """One per-label or flat classification result."""
+
+    index: int | None = None
+    label: str | None = None
+    probs: list[float]
+    num_classes: int | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class ClassificationResponse(BaseModel):
+    id: str | None = None
+    object: str | None = None
+    created: int | None = None
+    model: str | None = None
+    data: list[ClassificationItem] = Field(default_factory=list)
+    usage: UsageTokens | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class LabelProbMap(BaseModel):
+    """Normalized probability map label -> value, with optional ordering."""
+
+    probs: dict[str, float]  # {"label": probability}
+    order: list[str] | None = None  # if you requested a specific order
+    model_config = ConfigDict(extra="forbid")
+
+
+# ============================================================
+# CriticClient
+# ============================================================
+
+
+class CriticClient(BaseModel):
+    """
+    Core inference client for the Critic classification service.
+
+    Owns:
+    - Configuration (server URL, API key, model, tokenizer, etc.)
+    - Label space (for predictions only)
+    - Message normalization and chat template formatting
+    - Inference via vLLM /classify endpoint
+
+    Does NOT handle:
+    - Dataset loading
+    - Ground truth extraction
+    - Evaluation / metrics
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
+
+    # --- connection / model config ---
+    server_url: str = Field(
+        default="https://all-hands-ai--critic-qwen3-4b-serve.modal.run",
+        description="Base URL of the vLLM classification service",
+    )
+    api_key: str | SecretStr = Field(
+        ..., description="API key for authenticating with the vLLM service"
+    )
+    model_name: str = Field(
+        default="critic-qwen3-4b", description="Name of the model to use"
+    )
+    tokenizer_name: str = Field(
+        default="Qwen/Qwen3-4B-Instruct-2507",
+        description="HuggingFace tokenizer name for loading chat template",
+    )
+    pass_tools_definitions: bool = Field(
+        default=True, description="Whether to pass tool definitions to the model"
+    )
+    timeout_seconds: float = Field(
+        default=300.0, description="Timeout for requests to the model"
+    )
+    has_success_label: bool = Field(
+        default=True, description="Whether the model predicts success label at index 0"
+    )
+
+    # --- runtime fields ---
+    _client: httpx.Client = PrivateAttr(default_factory=httpx.Client)
+    _template_renderer: ChatTemplateRenderer | None = PrivateAttr(default=None)
+
+    # --- label space ---
+    sentiment_labels: tuple[str, ...] = (
+        "sentiment_positive",
+        "sentiment_neutral",
+        "sentiment_negative",
+    )
+    agent_issue_labels: tuple[str, ...] = (
+        "misunderstood_intention",
+        "did_not_follow_instruction",
+        "insufficient_analysis",
+        "insufficient_clarification",
+        "improper_tool_use_or_setup",
+        "loop_behavior",
+        "insufficient_testing",
+        "insufficient_debugging",
+        "incomplete_implementation",
+        "file_management_errors",
+        "scope_creep",
+        "risky_actions_or_permission",
+        "other_agent_issue",
+    )
+    infra_labels: tuple[str, ...] = (
+        "infrastructure_external_issue",
+        "infrastructure_agent_caused_issue",
+    )
+    user_followup_labels: tuple[str, ...] = (
+        "clarification_or_restatement",
+        "correction",
+        "direction_change",
+        "vcs_update_requests",
+        "progress_or_scope_concern",
+        "frustration_or_complaint",
+        "removal_or_reversion_request",
+        "other_user_issue",
+    )
+    sentiment_map: dict[str, str] = {
+        "Positive": "sentiment_positive",
+        "Neutral": "sentiment_neutral",
+        "Negative": "sentiment_negative",
+    }
+
+    # ---------------------
+    # Validation
+    # ---------------------
+    @field_validator("api_key", mode="before")
+    @classmethod
+    def _validate_and_convert_api_key(cls, v: str | SecretStr) -> SecretStr:
+        """Convert str to SecretStr and validate non-empty."""
+        if isinstance(v, SecretStr):
+            secret_value = v.get_secret_value()
+        else:
+            secret_value = v
+
+        if not secret_value or not secret_value.strip():
+            raise ValueError("api_key must be non-empty")
+
+        return SecretStr(secret_value) if isinstance(v, str) else v
+
+    # ---------------------
+    # Label helpers
+    # ---------------------
+    @property
+    def all_labels(self) -> tuple[str, ...]:
+        base_labels = (
+            self.sentiment_labels
+            + self.agent_issue_labels
+            + self.infra_labels
+            + self.user_followup_labels
+        )
+        if self.has_success_label:
+            return ("success",) + base_labels
+        return base_labels
+
+    # ---------------------
+    # Tokenizer / formatting
+    # ---------------------
+    def _get_template_renderer(self) -> ChatTemplateRenderer:
+        """Lazily initialize the chat template renderer."""
+        if self._template_renderer is None:
+            self._template_renderer = ChatTemplateRenderer(
+                tokenizer_name=self.tokenizer_name
+            )
+        return self._template_renderer
+
+    @staticmethod
+    def normalize_messages(messages: Sequence[dict]) -> Sequence[dict]:
+        """Ensure messages all have string content and flatten text blocks."""
+        out: list[dict] = []
+        for msg in messages or []:
+            content = msg.get("content", "") or ""
+            if isinstance(content, list):
+                text_parts = [
+                    block.get("text", "")
+                    for block in content
+                    if isinstance(block, dict) and block.get("type") == "text"
+                ]
+                content = "\n".join(text_parts)
+            if not isinstance(content, str):
+                content = str(content)
+            out.append({"role": msg.get("role", ""), "content": content})
+        return out
+
+    def apply_chat_template(
+        self,
+        messages: Sequence[dict],
+        tools: Sequence[ChatCompletionToolParam] | None = None,
+    ) -> str:
+        renderer = self._get_template_renderer()
+        msgs = self.normalize_messages(copy.deepcopy(messages))
+        # Cast tools to Sequence[dict[str, Any]] for type compatibility
+        # ChatCompletionToolParam is a TypedDict which is structurally compatible
+        tools_dicts: Sequence[dict[str, Any]] | None = (
+            cast(Sequence[dict[str, Any]], tools) if tools is not None else None
+        )
+        if self.pass_tools_definitions and tools_dicts:
+            return renderer.apply_chat_template(
+                msgs, tools=tools_dicts, add_generation_prompt=False
+            )
+        return renderer.apply_chat_template(msgs, add_generation_prompt=False)
+
+    # ---------------------
+    # Inference
+    # ---------------------
+    def classify_trace(
+        self,
+        messages: Sequence[dict],
+        tools: Sequence[ChatCompletionToolParam] | None = None,
+    ) -> ClassificationResponse:
+        """POST /classify and parse response into ClassificationResponse."""
+        formatted = self.apply_chat_template(messages, tools)
+
+        def should_retry(exc: BaseException) -> bool:
+            # Retry only on 500 Internal Server Error
+            if isinstance(exc, httpx.HTTPStatusError):
+                return exc.response.status_code == 500
+            return False
+
+        @retry(
+            retry=retry_if_exception(should_retry),
+            stop=stop_after_attempt(3),  # up to 3 tries
+            wait=wait_exponential(
+                multiplier=1, min=1, max=8
+            ),  # exponential backoff: 1s, 2s, 4s, 8s
+            reraise=True,  # re-raise the last exception if all retries fail
+        )
+        def _post_with_retry():
+            api_key_value = (
+                self.api_key.get_secret_value()
+                if isinstance(self.api_key, SecretStr)
+                else self.api_key
+            )
+            resp = self._client.post(
+                f"{self.server_url}/classify",
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {api_key_value}",
+                },
+                json={"model": self.model_name, "input": formatted},
+                timeout=self.timeout_seconds,
+            )
+            resp.raise_for_status()
+            return resp
+
+        resp = _post_with_retry()
+        return ClassificationResponse.model_validate(resp.json())
+
+    # ---------------------
+    # Post-processing helpers
+    # ---------------------
+    def extract_prob_map(self, response: ClassificationResponse) -> LabelProbMap:
+        """
+        Server format (flat-only, strict):
+            response.data == [ ClassificationItem(probs=[p0, p1, ..., pN-1],
+                               num_classes=N) ]
+        We align probs directly to self.all_labels (same length, same order).
+        """
+        if not response.data:
+            raise ValueError("empty response.data from server")
+
+        item = response.data[0]
+        if not item.probs:
+            raise ValueError("server returned empty 'probs'")
+        if item.num_classes is not None and item.num_classes != len(item.probs):
+            raise ValueError(
+                f"num_classes ({item.num_classes}) does not match "
+                f"len(probs) ({len(item.probs)})"
+            )
+
+        probs = [float(x) for x in item.probs]
+        if len(probs) != len(self.all_labels):
+            raise ValueError(
+                f"len(probs) ({len(probs)}) != len(all_labels) "
+                f"({len(self.all_labels)}). "
+                "Ensure server label space matches client label space."
+            )
+
+        mapping = {lbl: probs[i] for i, lbl in enumerate(self.all_labels)}
+        return LabelProbMap(probs=mapping, order=list(self.all_labels))
+
+    def predict_labels(self, probs: list[float], threshold: float = 0.5) -> list[int]:
+        return [1 if p > threshold else 0 for p in probs]
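
A hedged sketch of calling the client on its own (the API key and messages are placeholders): classify_trace renders the trace with the chat template and POSTs it to {server_url}/classify, extract_prob_map aligns the returned probabilities with all_labels, and predict_labels applies a simple threshold:

    from openhands.sdk.critic.impl.api.client import CriticClient

    client = CriticClient(api_key="YOUR_API_KEY")  # placeholder key

    messages = [
        {"role": "user", "content": "Fix the failing test in utils.py"},
        {"role": "assistant", "content": "I ran the tests and patched utils.py."},
    ]

    response = client.classify_trace(messages)    # ClassificationResponse
    prob_map = client.extract_prob_map(response)  # probabilities keyed by label
    print(prob_map.probs["success"])
    print(client.predict_labels(list(prob_map.probs.values()), threshold=0.5))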

openhands/sdk/critic/impl/api/critic.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from openhands.sdk.critic.base import CriticBase, CriticResult
+from openhands.sdk.critic.impl.api.client import CriticClient
+from openhands.sdk.critic.impl.api.taxonomy import categorize_features
+
+
+if TYPE_CHECKING:
+    from openhands.sdk.event import LLMConvertibleEvent, SystemPromptEvent
+
+
+class APIBasedCritic(CriticBase, CriticClient):
+    def evaluate(
+        self,
+        events: Sequence[LLMConvertibleEvent],
+        git_patch: str | None = None,  # noqa: ARG002
+    ) -> CriticResult:
+        # Local imports to avoid circular dependencies during module load
+        from openhands.sdk.context.view import View
+        from openhands.sdk.event import LLMConvertibleEvent, SystemPromptEvent
+
+        system_prompt_event: SystemPromptEvent | None = None
+        tools = []
+        for event in events:
+            if isinstance(event, SystemPromptEvent):
+                system_prompt_event = event
+                tools = event.tools
+                break
+        if system_prompt_event is None:
+            raise ValueError(
+                "SystemPromptEvent is required for APIBasedCritic evaluation"
+            )
+        if not tools:
+            raise ValueError(
+                "APIBasedCritic requires tools to be defined in SystemPromptEvent. "
+                "Ensure your agent configuration includes tool definitions."
+            )
+            raise ValueError("Tools are required for APIBasedCritic evaluation")
+
+        # This will only retain events that are kept by the condenser
+        view = View.from_events(events)
+        llm_convertible_events = view.events
+
+        # Convert events to messages
+        messages = LLMConvertibleEvent.events_to_messages(llm_convertible_events)
+
+        # Serialize messages to dicts for API
+        for message in messages:
+            message.cache_enabled = False
+            message.vision_enabled = False  # Critic does not support vision currently
+            message.function_calling_enabled = True
+            message.force_string_serializer = False
+            message.send_reasoning_content = False
+        formatted_messages = [message.to_chat_dict() for message in messages]
+
+        # Convert ToolDefinition objects to ChatCompletionToolParam format
+        tools_for_api = [tool.to_openai_tool() for tool in tools]
+        response = self.classify_trace(formatted_messages, tools_for_api)
+        prob_map = self.extract_prob_map(response)
+
+        explanation = []
+
+        if "success" not in prob_map.probs:
+            raise ValueError("APIBasedCritic requires 'success' label in the response.")
+
+        score = prob_map.probs["success"]
+        explanation.append(f"Success: {score:.2f}")
+
+        # Add top labels to explanation
+        sorted_probs = sorted(prob_map.probs.items(), key=lambda x: x[1], reverse=True)
+        explanation.append(json.dumps(dict(sorted_probs)))
+
+        # Collect event IDs for reproducibility
+        event_ids = [event.id for event in llm_convertible_events]
+
+        # Categorize features for visualization
+        categorized = categorize_features(prob_map.probs)
+
+        return CriticResult(
+            score=score,
+            message="; ".join(explanation),
+            metadata={
+                "event_ids": event_ids,
+                "categorized_features": categorized,
+            },
+        )
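
A sketch of how the new critic might be wired into downstream code (assuming APIBasedCritic needs no required fields beyond the CriticClient api_key; the helper function and its name are illustrative only):

    from collections.abc import Sequence

    from openhands.sdk.critic.impl.api.critic import APIBasedCritic
    from openhands.sdk.event import LLMConvertibleEvent


    def score_trajectory(events: Sequence[LLMConvertibleEvent], api_key: str) -> float:
        # `events` must start with a SystemPromptEvent that carries tool
        # definitions, otherwise evaluate() raises ValueError (see diff above).
        critic = APIBasedCritic(api_key=api_key)
        result = critic.evaluate(events)
        # result.metadata includes "event_ids" and "categorized_features"
        return result.score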