hud-python 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +22 -2
- hud/adapters/claude/adapter.py +9 -2
- hud/adapters/claude/tests/__init__.py +1 -0
- hud/adapters/claude/tests/test_adapter.py +519 -0
- hud/adapters/common/types.py +5 -1
- hud/adapters/operator/adapter.py +4 -0
- hud/adapters/operator/tests/__init__.py +1 -0
- hud/adapters/operator/tests/test_adapter.py +370 -0
- hud/agent/__init__.py +4 -0
- hud/agent/base.py +18 -2
- hud/agent/claude.py +20 -17
- hud/agent/claude_plays_pokemon.py +282 -0
- hud/agent/langchain.py +12 -7
- hud/agent/misc/__init__.py +3 -0
- hud/agent/misc/response_agent.py +80 -0
- hud/agent/operator.py +27 -19
- hud/agent/tests/__init__.py +1 -0
- hud/agent/tests/test_base.py +202 -0
- hud/env/docker_client.py +28 -18
- hud/env/environment.py +32 -16
- hud/env/local_docker_client.py +83 -42
- hud/env/remote_client.py +1 -3
- hud/env/remote_docker_client.py +72 -15
- hud/exceptions.py +12 -0
- hud/gym.py +71 -53
- hud/job.py +52 -7
- hud/settings.py +6 -0
- hud/task.py +45 -33
- hud/taskset.py +44 -4
- hud/telemetry/__init__.py +21 -0
- hud/telemetry/_trace.py +173 -0
- hud/telemetry/context.py +193 -0
- hud/telemetry/exporter.py +417 -0
- hud/telemetry/instrumentation/__init__.py +3 -0
- hud/telemetry/instrumentation/mcp.py +498 -0
- hud/telemetry/instrumentation/registry.py +59 -0
- hud/telemetry/mcp_models.py +331 -0
- hud/telemetry/tests/__init__.py +1 -0
- hud/telemetry/tests/test_context.py +203 -0
- hud/telemetry/tests/test_trace.py +270 -0
- hud/types.py +10 -26
- hud/utils/common.py +22 -2
- hud/utils/misc.py +53 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +7 -0
- {hud_python-0.2.4.dist-info → hud_python-0.2.5.dist-info}/METADATA +90 -22
- hud_python-0.2.5.dist-info/RECORD +84 -0
- hud_python-0.2.4.dist-info/RECORD +0 -62
- {hud_python-0.2.4.dist-info → hud_python-0.2.5.dist-info}/WHEEL +0 -0
- {hud_python-0.2.4.dist-info → hud_python-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
6
|
+
|
|
7
|
+
# Import MCP types
|
|
8
|
+
from mcp.types import JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse
|
|
9
|
+
from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from mcp.shared.message import SessionMessage
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DirectionType(str, Enum):
    """Direction in which an MCP message travelled.

    The ``str`` mixin makes every member compare equal to its raw value,
    e.g. ``DirectionType.SENT == "sent"``.
    """

    # Outbound message; serialized as "sent".
    SENT = "sent"
    # Inbound message; serialized as "received".
    RECEIVED = "received"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class StatusType(str, Enum):
    """Lifecycle status attached to an MCP telemetry record.

    Thanks to the ``str`` mixin each member equals its string value.
    """

    # Operation has begun; value "started".
    STARTED = "started"
    # Operation finished; value "completed".
    COMPLETED = "completed"
    # Operation failed; value "error".
    ERROR = "error"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MCPCallType(str, Enum):
    """Well-known ``call_type`` identifiers for MCP telemetry records.

    ``BaseMCPCall.call_type`` is declared as a plain ``str``, so these values
    form a recommended vocabulary rather than an enforced one.
    """

    # Session-level operations (dotted paths under ``mcp.shared.session``).
    SEND_REQUEST = "mcp.shared.session.send_request"
    SEND_NOTIFICATION = "mcp.shared.session.send_notification"
    RECEIVE_RESPONSE = "mcp.shared.session.receive_response"
    RECEIVE_REQUEST = "mcp.shared.session.receive_request"
    # Raw stream traffic.
    STREAM_READ = "mcp.stream.read"
    STREAM_WRITE = "mcp.stream.write"
    # Handling of incoming messages.
    HANDLE_INCOMING = "mcp.handle_incoming"
    # Records built by hand rather than by instrumentation.
    MANUAL_TEST = "manual.test"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class BaseMCPCall(BaseModel):
    """Common fields and polymorphic loading for every MCP telemetry record.

    A subclass that declares its own ``call_type`` annotation with a string
    default is registered automatically. :meth:`from_dict` uses that registry
    to rebuild a serialized record as the matching subclass.
    """

    # Identifier of the task run this record belongs to.
    task_run_id: str
    # Free-form category; ideally an MCPCallType value, but not enforced.
    call_type: str
    # Creation time as a POSIX timestamp in seconds.
    timestamp: float = Field(default_factory=lambda: datetime.now().timestamp())
    method: str = "unknown_method"
    status: StatusType
    direction: DirectionType | None = None
    # Additional data that might be useful for any call
    message_id: str | int | None = None

    # call_type -> subclass registry, filled in by __init_subclass__. It is
    # shared by the whole hierarchy because writes always go through BaseMCPCall.
    _call_type_mapping: ClassVar[dict[str, type["BaseMCPCall"]]] = {}

    @field_validator("call_type")
    @classmethod
    def validate_call_type(cls, v: str) -> str:
        """Accept any ``call_type`` string unchanged (MCPCallType values preferred)."""
        return v

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> BaseMCPCall:
        """Validate *data* into the subclass registered for its ``call_type``.

        A missing or unregistered ``call_type`` falls back to BaseMCPCall.
        """
        target_cls = cls._call_type_mapping.get(data.get("call_type", ""), BaseMCPCall)
        return target_cls.model_validate(data)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Record *cls* in the registry under its default ``call_type``.

        Only subclasses that re-declare ``call_type`` in their own annotations
        and give it a ``str`` default (``MCPCallType`` members qualify) are
        registered. The default is read as a class attribute while this hook
        runs.
        """
        super().__init_subclass__(**kwargs)
        if not hasattr(cls, "__annotations__"):
            return
        if "call_type" not in cls.__annotations__:
            return
        declared_default = getattr(cls, "call_type", None)
        if not isinstance(declared_default, str):
            return
        BaseMCPCall._call_type_mapping[declared_default] = cls
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class MCPRequestCall(BaseMCPCall):
    """Telemetry record for a sent MCP JSON-RPC request."""

    direction: DirectionType = DirectionType.SENT
    call_type: str = MCPCallType.SEND_REQUEST
    # POSIX timestamp; the factories below stamp it with the current time.
    start_time: float
    end_time: float | None = None
    duration: float | None = None
    request_id: str | int | None = None
    # The request as produced by ``model_dump(exclude_none=True)``.
    request_data: dict[str, Any] | None = None

    @classmethod
    def from_jsonrpc_request(
        cls,
        request: JSONRPCRequest,
        task_run_id: str,
        status: StatusType = StatusType.STARTED,
        **kwargs: Any,
    ) -> MCPRequestCall:
        """Build a record from *request*.

        ``request.id`` fills both ``request_id`` and ``message_id``, and
        ``start_time`` is set to the current time. Extra keyword arguments are
        forwarded to the constructor. Repeating one of the fields set here
        raises ``TypeError``.
        """
        fields: dict[str, Any] = {
            "task_run_id": task_run_id,
            "status": status,
            "method": request.method,
            "request_id": request.id,
            "message_id": request.id,
            "request_data": request.model_dump(exclude_none=True),
            "start_time": datetime.now().timestamp(),
        }
        return cls(**fields, **kwargs)

    @classmethod
    def from_session_message(
        cls,
        message: SessionMessage,
        task_run_id: str,
        status: StatusType = StatusType.STARTED,
        **kwargs: Any,
    ) -> MCPRequestCall | None:
        """Build a record from a SessionMessage that wraps a JSONRPCRequest.

        Returns ``None`` when *message* has no ``message.root`` attribute or
        the root is not a JSONRPCRequest.
        """
        wraps_root = hasattr(message, "message") and hasattr(message.message, "root")
        if not wraps_root or not isinstance(message.message.root, JSONRPCRequest):
            return None
        return cls.from_jsonrpc_request(
            message.message.root, task_run_id=task_run_id, status=status, **kwargs
        )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class MCPResponseCall(BaseMCPCall):
    """Telemetry record for a received MCP JSON-RPC response or error."""

    direction: DirectionType = DirectionType.RECEIVED
    call_type: str = MCPCallType.RECEIVE_RESPONSE
    is_response_or_error: bool = True
    # True when the payload was a JSONRPCError.
    is_error: bool = False
    response_id: str | int | None = None
    related_request_id: str | int | None = None
    # The response as produced by ``model_dump(exclude_none=True)``.
    response_data: dict[str, Any] | None = None
    # For errors, filled from ``response.error.message`` and ``response.error.code``.
    error: str | None = None
    error_type: str | None = None

    @classmethod
    def from_jsonrpc_response(
        cls, response: JSONRPCResponse | JSONRPCError, task_run_id: str, **kwargs: Any
    ) -> MCPResponseCall:
        """Build a COMPLETED record from a JSONRPCResponse or JSONRPCError.

        ``response.id`` fills ``response_id``, ``message_id`` and
        ``related_request_id``. For a JSONRPCError, ``error`` and
        ``error_type`` are assigned after construction. Those assignments
        overwrite any values passed through ``**kwargs``.
        """
        failed = isinstance(response, JSONRPCError)
        reply_id = response.id

        record = cls(
            task_run_id=task_run_id,
            status=StatusType.COMPLETED,
            method=f"response_to_id_{reply_id}",
            message_id=reply_id,
            response_id=reply_id,
            # In MCP, response ID matches request ID
            related_request_id=reply_id,
            is_error=failed,
            response_data=response.model_dump(exclude_none=True),
            **kwargs,
        )

        if failed and hasattr(response, "error"):
            details = response.error
            record.error = details.message
            record.error_type = str(details.code)

        return record

    @classmethod
    def from_session_message(
        cls, message: SessionMessage, task_run_id: str, **kwargs: Any
    ) -> MCPResponseCall | None:
        """Build a record from a SessionMessage wrapping a response or error.

        Returns ``None`` when *message* has no ``message.root`` attribute or
        the root is neither a JSONRPCResponse nor a JSONRPCError.
        """
        wraps_root = hasattr(message, "message") and hasattr(message.message, "root")
        if not wraps_root:
            return None
        if not isinstance(message.message.root, JSONRPCResponse | JSONRPCError):
            return None
        return cls.from_jsonrpc_response(
            message.message.root, task_run_id=task_run_id, **kwargs
        )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class MCPNotificationCall(BaseMCPCall):
    """Telemetry record for a sent MCP JSON-RPC notification."""

    direction: DirectionType = DirectionType.SENT
    call_type: str = MCPCallType.SEND_NOTIFICATION
    # POSIX timestamp; the factories below stamp it with the current time.
    start_time: float
    end_time: float | None = None
    duration: float | None = None
    # The notification as produced by ``model_dump(exclude_none=True)``.
    notification_data: dict[str, Any] | None = None

    @classmethod
    def from_jsonrpc_notification(
        cls,
        notification: JSONRPCNotification,
        task_run_id: str,
        status: StatusType = StatusType.STARTED,
        **kwargs: Any,
    ) -> MCPNotificationCall:
        """Build a record from *notification*.

        ``message_id`` is not set here and stays at its default. Extra keyword
        arguments are forwarded to the constructor.
        """
        fields: dict[str, Any] = {
            "task_run_id": task_run_id,
            "status": status,
            "method": notification.method,
            "notification_data": notification.model_dump(exclude_none=True),
            "start_time": datetime.now().timestamp(),
        }
        return cls(**fields, **kwargs)

    @classmethod
    def from_session_message(
        cls,
        message: SessionMessage,
        task_run_id: str,
        status: StatusType = StatusType.STARTED,
        **kwargs: Any,
    ) -> MCPNotificationCall | None:
        """Build a record from a SessionMessage wrapping a JSONRPCNotification.

        Returns ``None`` when *message* has no ``message.root`` attribute or
        the root is not a JSONRPCNotification.
        """
        wraps_root = hasattr(message, "message") and hasattr(message.message, "root")
        if not wraps_root or not isinstance(message.message.root, JSONRPCNotification):
            return None
        return cls.from_jsonrpc_notification(
            message.message.root, task_run_id=task_run_id, status=status, **kwargs
        )
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class MCPStreamEvent(BaseMCPCall):
    """Telemetry record for a read or write observed on an MCP stream.

    This class declares no ``call_type`` default, so callers are expected to
    supply one via ``**kwargs``.
    """

    stream_event: bool = True
    event_type: str = Field(..., description="Type of stream event: read or write")
    # Class name of the wrapped JSON-RPC root object, or "unknown".
    item_type: str | None = None
    is_response_or_error: bool = False
    message_data: dict[str, Any] | None = None

    @classmethod
    def from_session_message(
        cls, message: SessionMessage, task_run_id: str, event_type: str, **kwargs: Any
    ) -> MCPStreamEvent:
        """Build a COMPLETED stream-event record describing *message*.

        If *message* wraps a root object, that object is dumped into
        ``message_data``. Its type name goes into ``item_type``, and the
        ``method`` label is set as follows:

        * requests and notifications use their JSON-RPC method;
        * responses and errors use ``response_to_id_<id>`` and set
          ``is_response_or_error``.
        """
        # Fallbacks for messages that do not wrap a JSON-RPC root object.
        method_name, item_type = "unknown_stream_operation", "unknown"
        is_response, message_data = False, None

        has_root = hasattr(message, "message") and hasattr(message.message, "root")
        if has_root:
            root = message.message.root
            item_type = type(root).__name__
            message_data = root.model_dump(exclude_none=True)

            # Check type first before accessing attributes.
            is_call = isinstance(root, JSONRPCRequest | JSONRPCNotification)
            is_reply = isinstance(root, JSONRPCResponse | JSONRPCError)
            if is_call and hasattr(root, "method"):
                method_name = root.method
            elif is_reply and hasattr(root, "id"):
                method_name = f"response_to_id_{root.id}"
                is_response = True

        return cls(
            task_run_id=task_run_id,
            status=StatusType.COMPLETED,
            method=method_name,
            event_type=event_type,
            item_type=item_type,
            is_response_or_error=is_response,
            message_data=message_data,
            timestamp=datetime.now().timestamp(),
            **kwargs,
        )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class MCPManualTestCall(BaseMCPCall):
    """Telemetry record built by hand, carrying arbitrary key/value data."""

    call_type: str = MCPCallType.MANUAL_TEST
    # Arbitrary payload collected from create(**custom_data).
    custom_data: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def create(cls, task_run_id: str, **custom_data: Any) -> MCPManualTestCall:
        """Return a COMPLETED manual-test record for *task_run_id*.

        Every extra keyword argument ends up in ``custom_data``.
        """
        created_at = datetime.now().timestamp()
        return cls(
            task_run_id=task_run_id,
            custom_data=custom_data,
            status=StatusType.COMPLETED,
            timestamp=created_at,
        )
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class MCPTelemetryRecord(BaseModel):
    """Container for a set of related MCP telemetry records.

    Groups records under a ``task_run_id`` together with the container's
    creation time, and exposes simple aggregate counts.
    """

    task_run_id: str
    records: list[BaseMCPCall]
    # Creation time of the container as a POSIX timestamp in seconds.
    timestamp: float = Field(default_factory=lambda: datetime.now().timestamp())

    @property
    def count_by_type(self) -> dict[str, int]:
        """Count records by ``call_type``.

        Returns:
            A mapping from plain call-type string to record count. Enum
            members are reduced to their ``.value``, so keys are always plain
            ``str`` and equal values share one key.
        """
        counts: dict[str, int] = {}
        for record in self.records:
            call_type = record.call_type
            # Pydantic does not validate field defaults by default, so a
            # subclass's MCPCallType default can arrive here as an enum member
            # while loaded records carry plain strings. Key by the wire value
            # so the result is uniform (and formats cleanly on Python 3.12+).
            key = call_type.value if isinstance(call_type, Enum) else call_type
            counts[key] = counts.get(key, 0) + 1
        return counts

    @property
    def count_by_direction(self) -> dict[str, int]:
        """Count records by direction, skipping records whose direction is None.

        Returns:
            A mapping from direction value (e.g. ``"sent"``) to record count.
        """
        counts: dict[str, int] = {}
        for record in self.records:
            direction = record.direction
            if direction is None:
                continue
            # Accept a raw string (e.g. from model_construct) as well as a
            # DirectionType member instead of failing on ``.value``.
            key = direction.value if isinstance(direction, Enum) else str(direction)
            counts[key] = counts.get(key, 0) + 1
        return counts
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class TrajectoryStep(BaseModel):
    """One exported step of a trajectory.

    Both timestamps are expected as ISO 8601 strings.
    """

    # Step kind; MCP calls use the default "mcp-step".
    type: str = "mcp-step"
    observation_url: str | None = None
    observation_text: str | None = None
    # Each entry is a free-form action dict.
    actions: list[dict[str, Any]] = Field(default_factory=list)
    # ISO 8601 format
    start_timestamp: str | None = None
    # ISO 8601 format
    end_timestamp: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Tests for hud.telemetry module
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from unittest.mock import MagicMock
|
|
4
|
+
|
|
5
|
+
from hud.telemetry.context import (
|
|
6
|
+
buffer_mcp_call,
|
|
7
|
+
flush_buffer,
|
|
8
|
+
get_current_task_run_id,
|
|
9
|
+
is_root_trace,
|
|
10
|
+
set_current_task_run_id,
|
|
11
|
+
)
|
|
12
|
+
from hud.telemetry.mcp_models import BaseMCPCall
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestTaskRunIdContext:
    """Test task run ID context management."""

    def teardown_method(self):
        """Clear the task run ID so a value set here cannot leak into other tests.

        pytest calls ``teardown_method`` after every test in this class.
        """
        set_current_task_run_id(None)

    def test_get_current_task_run_id_initial(self):
        """Test getting task run ID when none is set."""
        # Reset context for clean test
        set_current_task_run_id(None)
        result = get_current_task_run_id()
        assert result is None

    def test_set_and_get_task_run_id(self):
        """Test setting and getting task run ID."""
        test_id = "test-task-run-id"
        set_current_task_run_id(test_id)
        result = get_current_task_run_id()
        assert result == test_id

    def test_task_run_id_isolation(self):
        """Test that setting a new task run ID replaces the previous one, and None clears it."""
        set_current_task_run_id("context-1")
        assert get_current_task_run_id() == "context-1"

        set_current_task_run_id("context-2")
        assert get_current_task_run_id() == "context-2"

        # Reset to None
        set_current_task_run_id(None)
        assert get_current_task_run_id() is None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TestRootTraceContext:
    """Tests for the ``is_root_trace`` flag."""

    def test_is_root_trace_initial(self):
        """Whatever its starting value, the flag reads back as a bool."""
        # The starting value depends on earlier code, so only the type is checked.
        assert isinstance(is_root_trace.get(), bool)

    def test_set_root_trace(self):
        """Each value written to the flag is read back unchanged (ends on False)."""
        for flag in (True, False):
            is_root_trace.set(flag)
            assert is_root_trace.get() is flag
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class TestMCPCallBuffer:
    """Test MCP call buffer management."""

    def setup_method(self):
        """Start every test with an empty buffer and no active task run ID.

        pytest calls ``setup_method`` automatically for plain test classes.
        A unittest-style ``setUp`` is only honoured on ``unittest.TestCase``
        subclasses, which is why each test previously had to call it by hand.
        """
        flush_buffer()
        set_current_task_run_id(None)

    def teardown_method(self):
        """Leave no buffered calls or active task run ID behind for later tests."""
        flush_buffer()
        set_current_task_run_id(None)

    @staticmethod
    def _mock_call(task_run_id, payload_type="test"):
        """Build a BaseMCPCall-shaped mock tagged with *task_run_id*."""
        mock_call = MagicMock(spec=BaseMCPCall)
        mock_call.model_dump.return_value = {"type": payload_type, "task_run_id": task_run_id}
        mock_call.task_run_id = task_run_id
        return mock_call

    def test_flush_buffer_empty(self):
        """Test flushing empty buffer."""
        result = flush_buffer()
        assert result == []

    def test_add_and_flush_mcp_call(self):
        """Test adding and flushing MCP calls."""
        # Set active task run ID
        set_current_task_run_id("test-task")

        mock_call = self._mock_call("test-task")
        buffer_mcp_call(mock_call)

        # Flush should return the call and clear buffer
        result = flush_buffer()
        assert len(result) == 1
        assert result[0] == mock_call

        # Buffer should be empty after flush
        result2 = flush_buffer()
        assert result2 == []

    def test_add_multiple_mcp_calls(self):
        """Test adding multiple MCP calls."""
        # Set active task run ID
        set_current_task_run_id("test-task")

        mock_calls = []
        for i in range(3):
            mock_call = self._mock_call("test-task", f"test_{i}")
            mock_calls.append(mock_call)
            buffer_mcp_call(mock_call)

        # Flush should return all calls
        result = flush_buffer()
        assert len(result) == 3
        assert result == mock_calls

    def test_buffer_isolation_per_task(self):
        """Test that MCP call buffers contain all calls regardless of task ID."""
        # Set task run ID 1
        set_current_task_run_id("task-1")
        mock_call_1 = self._mock_call("task-1")
        buffer_mcp_call(mock_call_1)

        # Set task run ID 2
        set_current_task_run_id("task-2")
        mock_call_2 = self._mock_call("task-2")
        buffer_mcp_call(mock_call_2)

        # Flush should return all calls from both tasks
        result = flush_buffer()
        assert len(result) == 2
        assert result[0] == mock_call_1
        assert result[1] == mock_call_2

    def test_buffer_mcp_call_without_task_id(self):
        """Test adding MCP call when no task run ID is set."""
        set_current_task_run_id(None)

        mock_call = MagicMock(spec=BaseMCPCall)
        mock_call.task_run_id = None
        buffer_mcp_call(mock_call)

        # Should not buffer anything when no task ID is set
        result = flush_buffer()
        assert len(result) == 0
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class TestContextIntegration:
    """Integration tests for context management."""

    def test_context_lifecycle(self):
        """Test complete context lifecycle."""
        # Start with clean state
        set_current_task_run_id(None)
        flush_buffer()
        is_root_trace.set(False)

        task_id = "integration-test-task"
        try:
            # Set up trace context
            set_current_task_run_id(task_id)
            is_root_trace.set(True)

            # Add some MCP calls
            mock_calls = []
            for i in range(2):
                mock_call = MagicMock(spec=BaseMCPCall)
                mock_call.model_dump.return_value = {
                    "type": f"integration_test_{i}",
                    "task_run_id": task_id,
                }
                mock_call.task_run_id = task_id
                mock_calls.append(mock_call)
                buffer_mcp_call(mock_call)

            # Verify context state
            assert get_current_task_run_id() == task_id
            assert is_root_trace.get() is True

            # Flush and verify
            result = flush_buffer()
            assert len(result) == 2
            assert result == mock_calls
        finally:
            # Clean up even when an assertion above fails, so the trace
            # context does not leak into other tests.
            set_current_task_run_id(None)
            is_root_trace.set(False)

        # Verify cleanup
        assert get_current_task_run_id() is None
        assert flush_buffer() == []
|