openhands-sdk 1.8.1__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- openhands/sdk/agent/agent.py +64 -0
- openhands/sdk/agent/base.py +29 -10
- openhands/sdk/agent/prompts/system_prompt.j2 +1 -0
- openhands/sdk/context/condenser/llm_summarizing_condenser.py +7 -5
- openhands/sdk/context/skills/skill.py +59 -1
- openhands/sdk/context/skills/utils.py +6 -65
- openhands/sdk/context/view.py +6 -11
- openhands/sdk/conversation/base.py +5 -0
- openhands/sdk/conversation/event_store.py +84 -12
- openhands/sdk/conversation/impl/local_conversation.py +7 -0
- openhands/sdk/conversation/impl/remote_conversation.py +16 -3
- openhands/sdk/conversation/state.py +25 -2
- openhands/sdk/conversation/visualizer/base.py +23 -0
- openhands/sdk/critic/__init__.py +4 -1
- openhands/sdk/critic/base.py +17 -20
- openhands/sdk/critic/impl/__init__.py +2 -0
- openhands/sdk/critic/impl/agent_finished.py +9 -5
- openhands/sdk/critic/impl/api/__init__.py +18 -0
- openhands/sdk/critic/impl/api/chat_template.py +232 -0
- openhands/sdk/critic/impl/api/client.py +313 -0
- openhands/sdk/critic/impl/api/critic.py +90 -0
- openhands/sdk/critic/impl/api/taxonomy.py +180 -0
- openhands/sdk/critic/result.py +148 -0
- openhands/sdk/event/conversation_error.py +12 -0
- openhands/sdk/event/llm_convertible/action.py +10 -0
- openhands/sdk/event/llm_convertible/message.py +10 -0
- openhands/sdk/git/cached_repo.py +459 -0
- openhands/sdk/git/utils.py +118 -3
- openhands/sdk/hooks/__init__.py +7 -1
- openhands/sdk/hooks/config.py +154 -45
- openhands/sdk/io/base.py +52 -0
- openhands/sdk/io/local.py +25 -0
- openhands/sdk/io/memory.py +34 -1
- openhands/sdk/llm/llm.py +6 -2
- openhands/sdk/llm/utils/model_features.py +3 -0
- openhands/sdk/llm/utils/telemetry.py +41 -2
- openhands/sdk/plugin/__init__.py +17 -0
- openhands/sdk/plugin/fetch.py +231 -0
- openhands/sdk/plugin/plugin.py +61 -4
- openhands/sdk/plugin/types.py +394 -1
- openhands/sdk/secret/secrets.py +19 -4
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/METADATA +6 -1
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/RECORD +45 -37
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/WHEEL +1 -1
- {openhands_sdk-1.8.1.dist-info → openhands_sdk-1.9.0.dist-info}/top_level.txt +0 -0

--- /dev/null
+++ b/openhands/sdk/critic/impl/api/client.py
@@ -0,0 +1,313 @@
+import copy
+from collections.abc import Sequence
+from typing import Any, cast
+
+import httpx
+from litellm import ChatCompletionToolParam
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    field_validator,
+)
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+
+from .chat_template import ChatTemplateRenderer
+
+
+# ============================================================
+# Typed API response models
+# ============================================================
+
+
+class UsageTokens(BaseModel):
+    prompt_tokens: int | None = None
+    total_tokens: int | None = None
+    completion_tokens: int | None = None
+    prompt_tokens_details: dict | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class ClassificationItem(BaseModel):
+    """One per-label or flat classification result."""
+
+    index: int | None = None
+    label: str | None = None
+    probs: list[float]
+    num_classes: int | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class ClassificationResponse(BaseModel):
+    id: str | None = None
+    object: str | None = None
+    created: int | None = None
+    model: str | None = None
+    data: list[ClassificationItem] = Field(default_factory=list)
+    usage: UsageTokens | None = None
+    model_config = ConfigDict(extra="allow")
+
+
+class LabelProbMap(BaseModel):
+    """Normalized probability map label -> value, with optional ordering."""
+
+    probs: dict[str, float]  # {"label": probability}
+    order: list[str] | None = None  # if you requested a specific order
+    model_config = ConfigDict(extra="forbid")
+
+
+# ============================================================
+# CriticClient
+# ============================================================
+
+
+class CriticClient(BaseModel):
+    """
+    Core inference client for the Critic classification service.
+
+    Owns:
+    - Configuration (server URL, API key, model, tokenizer, etc.)
+    - Label space (for predictions only)
+    - Message normalization and chat template formatting
+    - Inference via vLLM /classify endpoint
+
+    Does NOT handle:
+    - Dataset loading
+    - Ground truth extraction
+    - Evaluation / metrics
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
+
+    # --- connection / model config ---
+    server_url: str = Field(
+        default="https://all-hands-ai--critic-qwen3-4b-serve.modal.run",
+        description="Base URL of the vLLM classification service",
+    )
+    api_key: str | SecretStr = Field(
+        ..., description="API key for authenticating with the vLLM service"
+    )
+    model_name: str = Field(
+        default="critic-qwen3-4b", description="Name of the model to use"
+    )
+    tokenizer_name: str = Field(
+        default="Qwen/Qwen3-4B-Instruct-2507",
+        description="HuggingFace tokenizer name for loading chat template",
+    )
+    pass_tools_definitions: bool = Field(
+        default=True, description="Whether to pass tool definitions to the model"
+    )
+    timeout_seconds: float = Field(
+        default=300.0, description="Timeout for requests to the model"
+    )
+    has_success_label: bool = Field(
+        default=True, description="Whether the model predicts success label at index 0"
+    )
+
+    # --- runtime fields ---
+    _client: httpx.Client = PrivateAttr(default_factory=httpx.Client)
+    _template_renderer: ChatTemplateRenderer | None = PrivateAttr(default=None)
+
+    # --- label space ---
+    sentiment_labels: tuple[str, ...] = (
+        "sentiment_positive",
+        "sentiment_neutral",
+        "sentiment_negative",
+    )
+    agent_issue_labels: tuple[str, ...] = (
+        "misunderstood_intention",
+        "did_not_follow_instruction",
+        "insufficient_analysis",
+        "insufficient_clarification",
+        "improper_tool_use_or_setup",
+        "loop_behavior",
+        "insufficient_testing",
+        "insufficient_debugging",
+        "incomplete_implementation",
+        "file_management_errors",
+        "scope_creep",
+        "risky_actions_or_permission",
+        "other_agent_issue",
+    )
+    infra_labels: tuple[str, ...] = (
+        "infrastructure_external_issue",
+        "infrastructure_agent_caused_issue",
+    )
+    user_followup_labels: tuple[str, ...] = (
+        "clarification_or_restatement",
+        "correction",
+        "direction_change",
+        "vcs_update_requests",
+        "progress_or_scope_concern",
+        "frustration_or_complaint",
+        "removal_or_reversion_request",
+        "other_user_issue",
+    )
+    sentiment_map: dict[str, str] = {
+        "Positive": "sentiment_positive",
+        "Neutral": "sentiment_neutral",
+        "Negative": "sentiment_negative",
+    }
+
+    # ---------------------
+    # Validation
+    # ---------------------
+    @field_validator("api_key", mode="before")
+    @classmethod
+    def _validate_and_convert_api_key(cls, v: str | SecretStr) -> SecretStr:
+        """Convert str to SecretStr and validate non-empty."""
+        if isinstance(v, SecretStr):
+            secret_value = v.get_secret_value()
+        else:
+            secret_value = v
+
+        if not secret_value or not secret_value.strip():
+            raise ValueError("api_key must be non-empty")
+
+        return SecretStr(secret_value) if isinstance(v, str) else v
+
+    # ---------------------
+    # Label helpers
+    # ---------------------
+    @property
+    def all_labels(self) -> tuple[str, ...]:
+        base_labels = (
+            self.sentiment_labels
+            + self.agent_issue_labels
+            + self.infra_labels
+            + self.user_followup_labels
+        )
+        if self.has_success_label:
+            return ("success",) + base_labels
+        return base_labels
+
+    # ---------------------
+    # Tokenizer / formatting
+    # ---------------------
+    def _get_template_renderer(self) -> ChatTemplateRenderer:
+        """Lazily initialize the chat template renderer."""
+        if self._template_renderer is None:
+            self._template_renderer = ChatTemplateRenderer(
+                tokenizer_name=self.tokenizer_name
+            )
+        return self._template_renderer
+
+    @staticmethod
+    def normalize_messages(messages: Sequence[dict]) -> Sequence[dict]:
+        """Ensure messages all have string content and flatten text blocks."""
+        out: list[dict] = []
+        for msg in messages or []:
+            content = msg.get("content", "") or ""
+            if isinstance(content, list):
+                text_parts = [
+                    block.get("text", "")
+                    for block in content
+                    if isinstance(block, dict) and block.get("type") == "text"
+                ]
+                content = "\n".join(text_parts)
+            if not isinstance(content, str):
+                content = str(content)
+            out.append({"role": msg.get("role", ""), "content": content})
+        return out
+
+    def apply_chat_template(
+        self,
+        messages: Sequence[dict],
+        tools: Sequence[ChatCompletionToolParam] | None = None,
+    ) -> str:
+        renderer = self._get_template_renderer()
+        msgs = self.normalize_messages(copy.deepcopy(messages))
+        # Cast tools to Sequence[dict[str, Any]] for type compatibility
+        # ChatCompletionToolParam is a TypedDict which is structurally compatible
+        tools_dicts: Sequence[dict[str, Any]] | None = (
+            cast(Sequence[dict[str, Any]], tools) if tools is not None else None
+        )
+        if self.pass_tools_definitions and tools_dicts:
+            return renderer.apply_chat_template(
+                msgs, tools=tools_dicts, add_generation_prompt=False
+            )
+        return renderer.apply_chat_template(msgs, add_generation_prompt=False)
+
+    # ---------------------
+    # Inference
+    # ---------------------
+    def classify_trace(
+        self,
+        messages: Sequence[dict],
+        tools: Sequence[ChatCompletionToolParam] | None = None,
+    ) -> ClassificationResponse:
+        """POST /classify and parse response into ClassificationResponse."""
+        formatted = self.apply_chat_template(messages, tools)
+
+        def should_retry(exc: BaseException) -> bool:
+            # Retry only on 500 Internal Server Error
+            if isinstance(exc, httpx.HTTPStatusError):
+                return exc.response.status_code == 500
+            return False
+
+        @retry(
+            retry=retry_if_exception(should_retry),
+            stop=stop_after_attempt(3),  # up to 3 tries
+            wait=wait_exponential(
+                multiplier=1, min=1, max=8
+            ),  # exponential backoff: 1s, 2s, 4s, 8s
+            reraise=True,  # re-raise the last exception if all retries fail
+        )
+        def _post_with_retry():
+            api_key_value = (
+                self.api_key.get_secret_value()
+                if isinstance(self.api_key, SecretStr)
+                else self.api_key
+            )
+            resp = self._client.post(
+                f"{self.server_url}/classify",
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {api_key_value}",
+                },
+                json={"model": self.model_name, "input": formatted},
+                timeout=self.timeout_seconds,
+            )
+            resp.raise_for_status()
+            return resp
+
+        resp = _post_with_retry()
+        return ClassificationResponse.model_validate(resp.json())
+
+    # ---------------------
+    # Post-processing helpers
+    # ---------------------
+    def extract_prob_map(self, response: ClassificationResponse) -> LabelProbMap:
+        """
+        Server format (flat-only, strict):
+            response.data == [ ClassificationItem(probs=[p0, p1, ..., pN-1],
+                               num_classes=N) ]
+        We align probs directly to self.all_labels (same length, same order).
+        """
+        if not response.data:
+            raise ValueError("empty response.data from server")
+
+        item = response.data[0]
+        if not item.probs:
+            raise ValueError("server returned empty 'probs'")
+        if item.num_classes is not None and item.num_classes != len(item.probs):
+            raise ValueError(
+                f"num_classes ({item.num_classes}) does not match "
+                f"len(probs) ({len(item.probs)})"
+            )
+
+        probs = [float(x) for x in item.probs]
+        if len(probs) != len(self.all_labels):
+            raise ValueError(
+                f"len(probs) ({len(probs)}) != len(all_labels) "
+                f"({len(self.all_labels)}). "
+                "Ensure server label space matches client label space."
+            )
+
+        mapping = {lbl: probs[i] for i, lbl in enumerate(self.all_labels)}
+        return LabelProbMap(probs=mapping, order=list(self.all_labels))
+
+    def predict_labels(self, probs: list[float], threshold: float = 0.5) -> list[int]:
+        return [1 if p > threshold else 0 for p in probs]
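
A minimal usage sketch for the new client (the key and messages below are hypothetical; only api_key is required, with server_url, model_name, and tokenizer_name falling back to the defaults shown above):

    from openhands.sdk.critic.impl.api.client import CriticClient

    # Hypothetical credentials and trace, for illustration only.
    client = CriticClient(api_key="sk-example")
    messages = [
        {"role": "user", "content": "Fix the failing unit test."},
        {"role": "assistant", "content": "Updated the assertion; the test now passes."},
    ]
    response = client.classify_trace(messages)    # POST {server_url}/classify
    prob_map = client.extract_prob_map(response)  # label -> probability, aligned to all_labels
    flags = client.predict_labels(list(prob_map.probs.values()))  # binarize at 0.5
    print(prob_map.probs.get("success"), flags)
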
--- /dev/null
+++ b/openhands/sdk/critic/impl/api/critic.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from openhands.sdk.critic.base import CriticBase, CriticResult
+from openhands.sdk.critic.impl.api.client import CriticClient
+from openhands.sdk.critic.impl.api.taxonomy import categorize_features
+
+
+if TYPE_CHECKING:
+    from openhands.sdk.event import LLMConvertibleEvent, SystemPromptEvent
+
+
+class APIBasedCritic(CriticBase, CriticClient):
+    def evaluate(
+        self,
+        events: Sequence[LLMConvertibleEvent],
+        git_patch: str | None = None,  # noqa: ARG002
+    ) -> CriticResult:
+        # Local imports to avoid circular dependencies during module load
+        from openhands.sdk.context.view import View
+        from openhands.sdk.event import LLMConvertibleEvent, SystemPromptEvent
+
+        system_prompt_event: SystemPromptEvent | None = None
+        tools = []
+        for event in events:
+            if isinstance(event, SystemPromptEvent):
+                system_prompt_event = event
+                tools = event.tools
+                break
+        if system_prompt_event is None:
+            raise ValueError(
+                "SystemPromptEvent is required for APIBasedCritic evaluation"
+            )
+        if not tools:
+            raise ValueError(
+                "APIBasedCritic requires tools to be defined in SystemPromptEvent. "
+                "Ensure your agent configuration includes tool definitions."
+            )
+            raise ValueError("Tools are required for APIBasedCritic evaluation")
+
+        # This will only retain events that are kept by the condenser
+        view = View.from_events(events)
+        llm_convertible_events = view.events
+
+        # Convert events to messages
+        messages = LLMConvertibleEvent.events_to_messages(llm_convertible_events)
+
+        # Serialize messages to dicts for API
+        for message in messages:
+            message.cache_enabled = False
+            message.vision_enabled = False  # Critic does not support vision currently
+            message.function_calling_enabled = True
+            message.force_string_serializer = False
+            message.send_reasoning_content = False
+        formatted_messages = [message.to_chat_dict() for message in messages]
+
+        # Convert ToolDefinition objects to ChatCompletionToolParam format
+        tools_for_api = [tool.to_openai_tool() for tool in tools]
+        response = self.classify_trace(formatted_messages, tools_for_api)
+        prob_map = self.extract_prob_map(response)
+
+        explanation = []
+
+        if "success" not in prob_map.probs:
+            raise ValueError("APIBasedCritic requires 'success' label in the response.")
+
+        score = prob_map.probs["success"]
+        explanation.append(f"Success: {score:.2f}")
+
+        # Add top labels to explanation
+        sorted_probs = sorted(prob_map.probs.items(), key=lambda x: x[1], reverse=True)
+        explanation.append(json.dumps(dict(sorted_probs)))
+
+        # Collect event IDs for reproducibility
+        event_ids = [event.id for event in llm_convertible_events]
+
+        # Categorize features for visualization
+        categorized = categorize_features(prob_map.probs)
+
+        return CriticResult(
+            score=score,
+            message="; ".join(explanation),
+            metadata={
+                "event_ids": event_ids,
+                "categorized_features": categorized,
+            },
+        )
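
A sketch of how the critic is invoked on a finished trace (the wrapper below is illustrative, not part of the diff; per the checks above, the event sequence must contain a SystemPromptEvent carrying tool definitions):

    from collections.abc import Sequence

    from openhands.sdk.critic.impl.api.critic import APIBasedCritic

    def review_trace(events: Sequence) -> tuple[float, dict]:
        """Score an agent trace with the hosted critic (hypothetical key)."""
        critic = APIBasedCritic(api_key="sk-example")
        result = critic.evaluate(events)  # raises if no SystemPromptEvent or no tools
        return result.score, result.metadata["categorized_features"]
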
--- /dev/null
+++ b/openhands/sdk/critic/impl/api/taxonomy.py
@@ -0,0 +1,180 @@
+"""Critic taxonomy - mapping of features to categories for visualization."""
+
+import math
+from typing import Any
+
+
+# Feature to category mapping
+FEATURE_CATEGORIES: dict[str, str] = {
+    # General Context & Task Classification
+    "user_goal_summary": "general_context",
+    "overall_sentiment": "general_context",
+    # Agent Behavioral Issues
+    "misunderstood_intention": "agent_behavioral_issues",
+    "did_not_follow_instruction": "agent_behavioral_issues",
+    "insufficient_analysis": "agent_behavioral_issues",
+    "insufficient_clarification": "agent_behavioral_issues",
+    "improper_tool_use_or_setup": "agent_behavioral_issues",
+    "loop_behavior": "agent_behavioral_issues",
+    "insufficient_testing": "agent_behavioral_issues",
+    "insufficient_debugging": "agent_behavioral_issues",
+    "incomplete_implementation": "agent_behavioral_issues",
+    "file_management_errors": "agent_behavioral_issues",
+    "scope_creep": "agent_behavioral_issues",
+    "risky_actions_or_permission": "agent_behavioral_issues",
+    "other_agent_issue": "agent_behavioral_issues",
+    # User Follow-Up Patterns
+    "follow_up_timing": "user_followup_patterns",
+    "clarification_or_restatement": "user_followup_patterns",
+    "correction": "user_followup_patterns",
+    "direction_change": "user_followup_patterns",
+    "vcs_update_requests": "user_followup_patterns",
+    "progress_or_scope_concern": "user_followup_patterns",
+    "frustration_or_complaint": "user_followup_patterns",
+    "removal_or_reversion_request": "user_followup_patterns",
+    "other_user_issue": "user_followup_patterns",
+    # Infrastructure Issues
+    "infrastructure_external_issue": "infrastructure_issues",
+    "infrastructure_agent_caused_issue": "infrastructure_issues",
+}
+
+# Category display names for visualization
+CATEGORY_DISPLAY_NAMES: dict[str, str] = {
+    "general_context": "General Context",
+    "agent_behavioral_issues": "Detected Agent Behavioral Issues",
+    "user_followup_patterns": "Predicted User Follow-Up Patterns",
+    "infrastructure_issues": "Detected Infrastructure Issues",
+}
+
+
+def get_category(feature_name: str) -> str | None:
+    """Get the category for a feature.
+
+    Args:
+        feature_name: Name of the feature
+
+    Returns:
+        Category name or None if not found
+    """
+    return FEATURE_CATEGORIES.get(feature_name)
+
+
+def _softmax_normalize(probs: dict[str, float]) -> dict[str, float]:
+    """Apply softmax normalization to convert logits to probabilities.
+
+    Args:
+        probs: Dictionary of names to raw probability/logit values
+
+    Returns:
+        Dictionary with softmax-normalized probabilities that sum to 1.0
+    """
+    if not probs:
+        return {}
+
+    values = list(probs.values())
+    exp_values = [math.exp(v) for v in values]
+    exp_sum = sum(exp_values)
+    normalized = [exp_v / exp_sum for exp_v in exp_values]
+
+    return dict(zip(probs.keys(), normalized))
+
+
+def categorize_features(
+    probs_dict: dict[str, float],
+    display_threshold: float = 0.2,
+) -> dict[str, Any]:
+    """Categorize features from probability dictionary into taxonomy groups.
+
+    This function takes raw probability outputs from the critic model and
+    organizes them into categories ready for visualization.
+
+    Args:
+        probs_dict: Dictionary of feature names to probability values
+        display_threshold: Minimum probability to include a feature (default: 0.2)
+
+    Returns:
+        Dictionary with categorized features ready for visualization:
+        {
+            "sentiment": {
+                "predicted": "Neutral",
+                "probability": 0.77,
+                "all": {"positive": 0.10, "neutral": 0.77, "negative": 0.13}
+            },
+            "agent_behavioral_issues": [
+                {"name": "loop_behavior", "display_name": "Loop Behavior",
+                 "probability": 0.85},
+                ...
+            ],
+            "user_followup_patterns": [...],
+            "infrastructure_issues": [...],
+            "other": [...]
+        }
+    """
+    result: dict[str, Any] = {
+        "sentiment": None,
+        "agent_behavioral_issues": [],
+        "user_followup_patterns": [],
+        "infrastructure_issues": [],
+        "other": [],
+    }
+
+    # Extract sentiment features and apply softmax normalization
+    raw_sentiment_probs = {}
+    for feature_name, prob in probs_dict.items():
+        if feature_name.startswith("sentiment_"):
+            short_name = feature_name.replace("sentiment_", "")
+            raw_sentiment_probs[short_name] = prob
+
+    if raw_sentiment_probs:
+        # Apply softmax normalization to convert logits to probabilities
+        sentiment_probs = _softmax_normalize(raw_sentiment_probs)
+        max_sentiment = max(sentiment_probs.items(), key=lambda x: x[1])
+        result["sentiment"] = {
+            "predicted": max_sentiment[0].capitalize(),
+            "probability": max_sentiment[1],
+            "all": sentiment_probs,
+        }
+
+    # Categorize other features
+    for feature_name, prob in probs_dict.items():
+        # Skip sentiment features (already processed)
+        if feature_name.startswith("sentiment_"):
+            continue
+
+        # Skip 'success' as it's redundant with the score
+        if feature_name == "success":
+            continue
+
+        # Skip features below threshold
+        if prob < display_threshold:
+            continue
+
+        category = FEATURE_CATEGORIES.get(feature_name)
+        feature_entry = {
+            "name": feature_name,
+            "display_name": feature_name.replace("_", " ").title(),
+            "probability": prob,
+        }
+
+        if category == "general_context":
+            # Skip general context features for now
+            continue
+        elif category == "agent_behavioral_issues":
+            result["agent_behavioral_issues"].append(feature_entry)
+        elif category == "user_followup_patterns":
+            result["user_followup_patterns"].append(feature_entry)
+        elif category == "infrastructure_issues":
+            result["infrastructure_issues"].append(feature_entry)
+        else:
+            result["other"].append(feature_entry)
+
+    # Sort each category by probability (descending)
+    for key in [
+        "agent_behavioral_issues",
+        "user_followup_patterns",
+        "infrastructure_issues",
+        "other",
+    ]:
+        result[key] = sorted(result[key], key=lambda x: x["probability"], reverse=True)
+
+    return result
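
To illustrate the output shape with toy probabilities (categorize_features keeps only non-sentiment features at or above display_threshold and softmaxes the sentiment triple):

    from openhands.sdk.critic.impl.api.taxonomy import categorize_features

    probs = {
        "success": 0.91,              # skipped: redundant with the score
        "sentiment_positive": 0.10,
        "sentiment_neutral": 0.77,
        "sentiment_negative": 0.13,
        "loop_behavior": 0.85,        # kept: above the 0.2 threshold
        "correction": 0.05,           # dropped: below the 0.2 threshold
    }
    out = categorize_features(probs)
    # out["sentiment"]["predicted"] == "Neutral" (softmax is monotonic, so the
    # largest raw value wins); out["agent_behavioral_issues"] holds one entry,
    # {"name": "loop_behavior", "display_name": "Loop Behavior", "probability": 0.85}.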